From 1477ba0f9a5ba0a7b3d9291db1c358808ffc2060 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 27 Nov 2023 16:45:02 -0800 Subject: [PATCH 01/35] Add multi-LoRA support --------- Co-authored-by: Chen Shen Co-authored-by: Shreyas Krishnaswamy Co-authored-by: Avnish Narayan --- csrc/punica/LICENSE | 217 +++ csrc/punica/bgmv/bgmv_all.cu | 21 + csrc/punica/bgmv/bgmv_config.h | 53 + csrc/punica/bgmv/bgmv_impl.cuh | 294 ++++ csrc/punica/bgmv/vec_dtypes.cuh | 1324 +++++++++++++++++ csrc/punica/punica_ops.cc | 563 +++++++ setup.py | 60 +- tests/lora/__init__.py | 0 tests/lora/conftest.py | 139 ++ tests/lora/test_layers.py | 697 +++++++++ tests/lora/test_llama.py | 141 ++ tests/lora/test_lora.py | 224 +++ tests/lora/test_lora_manager.py | 473 ++++++ tests/lora/test_punica.py | 196 +++ tests/lora/test_tokenizer.py | 69 + tests/lora/test_utils.py | 172 +++ tests/lora/test_worker.py | 56 + tests/lora/utils.py | 88 ++ vllm/config.py | 31 + vllm/core/scheduler.py | 21 +- vllm/engine/arg_utils.py | 42 +- vllm/engine/async_llm_engine.py | 80 +- vllm/engine/llm_engine.py | 87 +- vllm/entrypoints/llm.py | 16 +- vllm/lora/__init__.py | 0 vllm/lora/layers.py | 1002 +++++++++++++ vllm/lora/lora.py | 120 ++ vllm/lora/models.py | 666 +++++++++ vllm/lora/punica.py | 173 +++ vllm/lora/request.py | 19 + vllm/lora/utils.py | 39 + vllm/lora/worker_manager.py | 266 ++++ vllm/model_executor/layers/sampler.py | 33 +- .../layers/vocab_parallel_embedding.py | 14 +- vllm/model_executor/model_loader.py | 14 +- vllm/model_executor/models/llama.py | 23 +- vllm/model_executor/models/mistral.py | 25 +- vllm/outputs.py | 19 +- vllm/sequence.py | 14 + vllm/transformers_utils/tokenizer.py | 82 + vllm/utils.py | 90 ++ vllm/worker/worker.py | 134 +- 42 files changed, 7713 insertions(+), 84 deletions(-) create mode 100644 csrc/punica/LICENSE create mode 100644 csrc/punica/bgmv/bgmv_all.cu create mode 100644 csrc/punica/bgmv/bgmv_config.h create mode 100644 csrc/punica/bgmv/bgmv_impl.cuh create mode 100644 csrc/punica/bgmv/vec_dtypes.cuh create mode 100644 csrc/punica/punica_ops.cc create mode 100644 tests/lora/__init__.py create mode 100644 tests/lora/conftest.py create mode 100644 tests/lora/test_layers.py create mode 100644 tests/lora/test_llama.py create mode 100644 tests/lora/test_lora.py create mode 100644 tests/lora/test_lora_manager.py create mode 100644 tests/lora/test_punica.py create mode 100644 tests/lora/test_tokenizer.py create mode 100644 tests/lora/test_utils.py create mode 100644 tests/lora/test_worker.py create mode 100644 tests/lora/utils.py create mode 100644 vllm/lora/__init__.py create mode 100644 vllm/lora/layers.py create mode 100644 vllm/lora/lora.py create mode 100644 vllm/lora/models.py create mode 100644 vllm/lora/punica.py create mode 100644 vllm/lora/request.py create mode 100644 vllm/lora/utils.py create mode 100644 vllm/lora/worker_manager.py diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE new file mode 100644 index 0000000000000..a46e2cdcadf7d --- /dev/null +++ b/csrc/punica/LICENSE @@ -0,0 +1,217 @@ +Contains code from https://github.com/punica-ai/punica + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. 
+ + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+ +------------------------------------------------------------------------------------ + +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache-2.0 +* third_party/nvbench (with LLVM exception) +* third_party/flashinfer + +BSD-3-Clause: +* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu new file mode 100644 index 0000000000000..bc86416701f13 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -0,0 +1,21 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h new file mode 100644 index 0000000000000..3fd56b685be13 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_config.h @@ -0,0 +1,53 @@ +#pragma once + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale); + +// clang-format off + +#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, narrow, 128) \ + f(in_T, out_T, W_T, narrow, 256) \ + f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 1024) \ + f(in_T, out_T, W_T, narrow, 1280) \ + f(in_T, out_T, W_T, narrow, 1728) \ + f(in_T, out_T, W_T, narrow, 1792) \ + f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2560) \ + f(in_T, out_T, W_T, narrow, 2752) \ + f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3456) \ + f(in_T, out_T, W_T, narrow, 3584) \ + f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 5120) \ + f(in_T, out_T, W_T, narrow, 5504) \ + f(in_T, out_T, W_T, narrow, 6912) \ + f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 8192) \ + f(in_T, out_T, W_T, narrow, 9216) \ + f(in_T, out_T, W_T, narrow, 10240) \ + f(in_T, out_T, W_T, narrow, 11008) \ + f(in_T, out_T, W_T, narrow, 12288) \ + f(in_T, out_T, W_T, narrow, 13824) \ + f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 16384) \ + 
f(in_T, out_T, W_T, narrow, 20480) \ + f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 32000) \ + f(in_T, out_T, W_T, narrow, 32256) \ + f(in_T, out_T, W_T, narrow, 36864) \ + f(in_T, out_T, W_T, narrow, 49152) \ + +#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + +// clang-format on diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh new file mode 100644 index 0000000000000..995de26e8bada --- /dev/null +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -0,0 +1,294 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "vec_dtypes.cuh" + +namespace cg = cooperative_groups; + +// nthrs = (32, 4) +template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t j = blockIdx.x; + constexpr size_t num_pipeline_stages = 2; + constexpr size_t tile_size = tx * ty * vec_size; + __shared__ W_T W_shared[num_pipeline_stages * tile_size]; + __shared__ in_T X_shared[num_pipeline_stages * tile_size]; + __shared__ float y_warpwise[ty]; + + size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + auto pipe = cuda::make_pipeline(); + + // pipeline load W/X and compute WX; + pipe.producer_acquire(); + cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + pipe.producer_commit(); + size_t copy_idx, compute_idx; + float y = 0.f; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; + ++tile_idx) { + copy_idx = tile_idx % num_pipeline_stages; + // pipeline stage: async copy W fragment + pipe.producer_acquire(); + if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { + cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + } + pipe.producer_commit(); + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // pipeline stage: compute WX + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i 
< vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = sum; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + } + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // final pipeline stage + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = + ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) + ? sum + : 0.f; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + + // write Y; + if (block.thread_rank() == 0) { + Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); + } +} + +// nthrs = (2, 16, 4) +template +__global__ void +bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t tile_idx = blockIdx.x; + + // load X; + vec_t x_vec; + x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); + + // load W; + vec_t w_vec; + w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + + block.thread_rank() * vec_size); + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } + + cg::thread_block_tile g = cg::tiled_partition(block); +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += g.shfl_down(sum, offset); + } + sum = g.shfl(sum, 0); + + if (threadIdx.x == 0) { + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y] += static_cast(sum); + } +} + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + constexpr size_t vec_size = 8; + constexpr int tz = 4; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if constexpr (feat_in < feat_out) { + static_assert(feat_in % vec_size == 0); + constexpr int tx = feat_in / vec_size; + + static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || + (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || + (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); + + if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { + constexpr int ty = 32 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if (16 % tx == 0 && feat_out % 
(16 / tx * tz) == 0) { + constexpr int ty = 16 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else { + constexpr int ty = 8 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } else { + static_assert(feat_in % (vec_size * 32) == 0 || + feat_in % (vec_size * 16) == 0 || + feat_in % (vec_size * 8) == 0); + + if constexpr (feat_in % (vec_size * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { + constexpr int tx = 16; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } +} + +#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ + template void bgmv_kernel( \ + out_T * __restrict__ Y, const in_T *__restrict__ X, \ + const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ + int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ + int64_t num_layers, int64_t layer_idx, float scale); + +#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ + INST_BGMV(narrow, wide, in_T, out_T, W_T) \ + INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh new file mode 100644 index 0000000000000..cf00d869cf635 --- /dev/null +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -0,0 +1,1324 @@ +#ifndef VEC_DTYPES_CUH_ +#define VEC_DTYPES_CUH_ + +#include +#include +#ifdef FLASHINFER_USE_FP8 +#include +#endif +#include + +#include + +#define FLASHINFER_INLINE \ + inline __attribute__((always_inline)) __device__ __host__ + +template +struct vec_t { + FLASHINFER_INLINE float_t &operator[](size_t i); + FLASHINFER_INLINE const float_t &operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t *ptr); + FLASHINFER_INLINE void store(float_t *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src); + template + FLASHINFER_INLINE void cast_load(const T *ptr); + template + FLASHINFER_INLINE void cast_store(T *ptr) const; + FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); +}; + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = tgt_float_t(src[i]); + } +} + +template +FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr, + vec_t &dst) { + if constexpr (std::is_same::value) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +FLASHINFER_INLINE void cast_store_impl(const vec_t &src, + tgt_float_t *dst_ptr) { + if constexpr (std::is_same::value) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +#ifdef 
FLASHINFER_USE_FP8 +/******************* vec_t<__nv_fp8_e4m3> *******************/ + +// __nv_fp8_e4m3 x 1 +template <> +struct vec_t<__nv_fp8_e4m3, 1> { + __nv_fp8_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( + __nv_fp8_e4m3 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *dst = *src; +} + +// __nv_fp8_e4m3 x 2 +template <> +struct vec_t<__nv_fp8_e4m3, 2> { + __nv_fp8x2_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x2_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x2_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 4 + +template <> +struct vec_t<__nv_fp8_e4m3, 4> { + __nv_fp8x4_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE 
void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x4_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x4_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 8 + +template <> +struct vec_t<__nv_fp8_e4m3, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { + ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( + __nv_fp8_e4m3 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 16 or more +template +struct vec_t<__nv_fp8_e4m3, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + 
(__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t<__nv_fp8_e5m2> *******************/ + +// __nv_fp8_e5m2 x 1 +template <> +struct vec_t<__nv_fp8_e5m2, 1> { + __nv_fp8_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( + __nv_fp8_e5m2 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *dst = *src; +} + +// __nv_fp8_e5m2 x 2 +template <> +struct vec_t<__nv_fp8_e5m2, 2> { + __nv_fp8x2_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 
2>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x2_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x2_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 4 + +template <> +struct vec_t<__nv_fp8_e5m2, 4> { + __nv_fp8x4_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x4_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x4_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 8 + +template <> +struct vec_t<__nv_fp8_e5m2, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { + ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( + __nv_fp8_e5m2 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( + __nv_fp8_e5m2 
*dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 16 or more + +template +struct vec_t<__nv_fp8_e5m2, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; +#endif + +/******************* vec_t *******************/ + +// half x 1 +template <> +struct vec_t { + half data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *dst = *src; +} + +// half x 2 +template <> +struct vec_t { + half2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half 
val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + data = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((half2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((half2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((half2 *)dst) = *((half2 *)src); +} + +// half x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2 *)(&data.x) = make_half2(val, val); + *(half2 *)(&data.y) = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// half x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)data)[i]; + } + FLASHINFER_INLINE void fill(half val) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + *(half2 *)(&(data[i].x)) = make_half2(val, val); + *(half2 *)(&(data[i].y)) = make_half2(val, val); + *(half2 *)(&(data[i].z)) = make_half2(val, val); + *(half2 *)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(half *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// nv_bfloat16 x 1 +template <> +struct vec_t { + nv_bfloat16 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return 
((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *dst = *src; +} + +// nv_bfloat16 x 2 +template <> +struct vec_t { + nv_bfloat162 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((nv_bfloat162 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((nv_bfloat162 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); +} + +// nv_bfloat16 x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const 
nv_bfloat16 *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// nv_bfloat16 x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); + } + } + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// float x 1 + +template <> +struct vec_t { + float data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *dst = *src; +} + +// float x 2 + +template <> +struct vec_t { + float2 data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { + data = make_float2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { 
+ data = *((float2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { + *((float2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *((float2 *)dst) = *((float2 *)src); +} + +// float x 4 or more +template +struct vec_t { + float4 data[vec_size / 4]; + + FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(data))[i]; + } + FLASHINFER_INLINE void fill(float val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + FLASHINFER_INLINE void load(const float *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(float *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)dst)[i] = ((float4 *)src)[i]; + } + } +}; + +/******************* vec_t type cast *******************/ + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = half(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = + __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = nv_bfloat16(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162 *)(&dst.data))[i] = + __float22bfloat162_rn(((float2 *)(&src.data))[i]); + } + } +} + +#ifdef FLASHINFER_USE_FP8 + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); 
+ } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = + __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = + __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +#endif // FLASHINFER_USE_FP8 + +#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc new file mode 100644 index 0000000000000..4ad46e5e1f726 --- /dev/null +++ b/csrc/punica/punica_ops.cc @@ -0,0 +1,563 @@ +#include +#include +#include + +#include + +#include "bgmv/bgmv_config.h" + +namespace { + +//====== utils ====== + +inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, + const char *a_name, const char *b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). 
", + a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, + ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) \ + TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) \ + TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +//====== bgmv ====== + +template +inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, + const int64_t *lora_indices, + uint16_t in_features, uint16_t out_features, + int64_t y_offset, int64_t full_y_size, + int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + switch (pack_u16(in_features, out_features)) { +#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ + case pack_u16(feat_in, feat_out): \ + bgmv_kernel(Y, X, W, lora_indices, y_offset, \ + full_y_size, batch_size, num_layers, \ + layer_idx, scale); \ + break; +#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) + + FOR_BGMV_WIDE_NARROW(CASE, _, _, _) +#undef CASE +#undef CASE_ONESIDE + default: + return false; + } + + return true; +} + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t h_in = x.size(1); + int64_t h_out = y.size(1); + int64_t num_layers = w.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + 
case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case 
at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t num_layers = w.size(1); + int64_t full_y_size = y.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + 
indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", 
h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +} // namespace + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/setup.py b/setup.py index 2b040e88f0aa4..2e11119043277 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,16 @@ +import contextlib import io import os import re import subprocess -from typing import List, Set import warnings +from pathlib import Path +from typing import List, Set from packaging.version import parse, Version import setuptools import torch +import torch.utils.cpp_extension as torch_cpp_ext from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME ROOT_DIR = os.path.dirname(__file__) @@ -31,6 +34,11 @@ "Cannot find CUDA_HOME. CUDA must be available to build the package.") +def glob(pattern: str): + root = Path(__name__).parent + return [str(p) for p in root.glob(pattern)] + + def get_nvcc_cuda_version(cuda_dir: str) -> Version: """Get the CUDA version from nvcc. @@ -129,19 +137,59 @@ def get_torch_arch_list() -> Set[str]: raise RuntimeError( "CUDA 11.8 or higher is required for compute capability 9.0.") +# Use NVCC threads to parallelize the build. +if nvcc_cuda_version >= Version("11.2"): + num_threads = min(os.cpu_count(), 8) + NVCC_FLAGS += ["--threads", str(num_threads)] + +NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() + # Add target compute capabilities to NVCC flags. for capability in compute_capabilities: num = capability[0] + capability[2] NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"] if capability.endswith("+PTX"): NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"] - -# Use NVCC threads to parallelize the build. 
-if nvcc_cuda_version >= Version("11.2"): - num_threads = min(os.cpu_count(), 8) - NVCC_FLAGS += ["--threads", str(num_threads)] + if int(capability[0]) >= 8: + NVCC_FLAGS_PUNICA += ["-gencode", f"arch=compute_{num},code=sm_{num}"] + if capability.endswith("+PTX"): + NVCC_FLAGS_PUNICA += [ + "-gencode", f"arch=compute_{num},code=compute_{num}" + ] + +# changes for punica kernels +NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS +REMOVE_NVCC_FLAGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', +] +for flag in REMOVE_NVCC_FLAGS: + with contextlib.suppress(ValueError): + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) ext_modules = [] + +install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1"))) +device_count = torch.cuda.device_count() +for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 8: + install_punica = False + break +if install_punica: + ext_modules.append( + CUDAExtension( + name="vllm._punica_C", + sources=["csrc/punica/punica_ops.cc"] + + glob("csrc/punica/bgmv/*.cu"), + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS_PUNICA, + }, + )) + vllm_extension = CUDAExtension( name="vllm._C", sources=[ diff --git a/tests/lora/__init__.py b/tests/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py new file mode 100644 index 0000000000000..263a2bc9d8156 --- /dev/null +++ b/tests/lora/conftest.py @@ -0,0 +1,139 @@ +import gc +import tempfile +from collections import OrderedDict +from unittest.mock import patch, MagicMock + +import pytest +import ray +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download + +import vllm +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + destroy_model_parallel, initialize_model_parallel) + + +def cleanup(): + destroy_model_parallel() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() + + +@pytest.fixture(autouse=True) +def cleanup_fixture(): + yield + cleanup() + + +@pytest.fixture +def dist_init(): + if not torch.distributed.is_initialized(): + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(1, 1) + yield + cleanup() + + +@pytest.fixture +def dist_init_torch_only(): + if torch.distributed.is_initialized(): + return + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + + +@pytest.fixture +def dummy_model() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", ColumnParallelLinear(50, 10)), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler 
+ ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture +def dummy_model_gate_up() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture(scope="session") +def sql_lora_files(): + return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + + +@pytest.fixture +def llama_2_7b_engine_extra_embeddings() -> nn.Module: + cleanup() + get_model_old = get_model + + def get_model_patched(model_config, lora_config=None): + return get_model_old(model_config, LoRAConfig(max_lora_rank=8)) + + with patch("vllm.worker.worker.get_model", get_model_patched): + engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) + yield engine.llm_engine + del engine + cleanup() + + +@pytest.fixture +def llama_2_7b_model_extra_embeddings( + llama_2_7b_engine_extra_embeddings) -> nn.Module: + yield llama_2_7b_engine_extra_embeddings.workers[0].model diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py new file mode 100644 index 0000000000000..fa6a18e8d93d2 --- /dev/null +++ b/tests/lora/test_layers.py @@ -0,0 +1,697 @@ +import pytest +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Dict, Tuple + +import torch +import torch.nn.functional as F + +from vllm.lora.layers import ( + LoRAColumnParallelLinear, + LoRAMergedColumnParallelLinear2Slice, + LoRAQKVParallelLinear, + LoRAVocabParallelEmbedding, + LoRARowParallelLinear, + LoRASampler, + LoRAMapping, + LoRALayer, +) +from vllm.lora.models import LoRA, convert_mapping +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.utils import set_random_seed + +from .utils import DummyLoRAManager + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. 
" + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + id_to_index: List[Optional[int]], + layer: LoRALayer, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRA], Dict[int, List[LoRA]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRA] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. Only useful when + # repeats > 1. + sublora_dict: Dict[int, List[LoRA]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager().init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = LoRA.pack(subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. 
+ """ + + low, high = input_range + + inputs, index_mapping, prompt_mapping = [], [], [] + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device="cuda")) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device="cuda") * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[512:, :] = 0 + lora_embedding = LoRAVocabParallelEmbedding(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = create_random_embedding_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info) + + lora_result = lora_embedding(torch.cat(inputs)) + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = embedding(input_) + after_a = F.embedding( + input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(inputs)) + expected_result = embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.") +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding_data = torch.rand_like(embedding.weight.data) + embedding.weight.data = embedding_data + 
embedding.weight.data[512:, :] = 0 + expanded_embedding = VocabParallelEmbedding( + 512 + lora_config.lora_extra_vocab_size * max_loras, + 256, + org_num_embeddings=512) + expanded_embedding.weight.data[:512, :] = embedding_data + # We need to deepcopy the embedding as it will be modifed + # in place + lora_embedding = LoRAVocabParallelEmbedding( + deepcopy(expanded_embedding)) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return expanded_embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + expanded_embedding, lora_embedding = create_random_embedding_layer() + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=torch.zeros( + (256, 512 + lora_config.lora_extra_vocab_size)), + generate_embeddings_tensor=256, + ) + + # All embeddings tensors have the same shape. + embeddings_tensors = [ + lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) + ] + embeddings_tensor_len = embeddings_tensors[0].shape[0] + + # Add empty embeddings_tensors for unoccupied lora slots. + for _ in range(max_loras - len(embeddings_tensors)): + embeddings_tensors.append( + torch.zeros(embeddings_tensors[0].shape, device="cuda")) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + # Force some of the inputs to be in the extended embeddings range + # to guarantee that their behavior is tested. + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + embedding_id = lora_id - 1 + input_[-1] = 512 + (embedding_id * embeddings_tensor_len) + original_input_[-1] = 512 + input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = 512 + embeddings_tensor_len - 1 + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + expanded_embedding.weight[512:512 + + (embeddings_tensor_len * + max_loras)] = torch.cat(embeddings_tensors) + + lora_result = lora_embedding(torch.cat(original_inputs)) + + expected_results = [] + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + lora = lora_dict[lora_id] + result = expanded_embedding(input_) + after_a = F.embedding( + original_input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(original_inputs)) + expected_result = expanded_embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + 
assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_lm_head_sampler(dist_init, num_loras) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_sampler_layer(): + linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, + 1024, 32000) + linear.weight.data = torch.rand_like(linear.weight.data) + linear.weight.data[:, 32000:] = 0 + sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) + lora_sampler = LoRASampler(sampler, 1024, linear.weight.dtype, + linear.weight.device) + lora_sampler.create_lora_weights(max_loras, lora_config) + + return linear, sampler, lora_sampler + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, sampler, lora_sampler = create_random_sampler_layer() + + # NOTE: all the generated loras share the same embeddings tensor. + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_sampler, + layer_weights=linear.weight, + generate_embeddings_tensor=1024, + ) + embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor + embeddings_tensor_len = embeddings_tensor.shape[0] + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=8 * num_loras, # * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + input_ = torch.rand(20, 1024, device="cuda") + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 32000, + lora_config.lora_extra_vocab_size, + ) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=linear.weight, + embedding_bias=None) + + original_weight = linear.weight.clone() + + linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + + embeddings_tensor_len] = embeddings_tensor + + sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = sampler._get_logits(hidden_states=input_, + embedding=linear.weight, + embedding_bias=None) + result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + sampler.org_vocab_size = 32000 + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_sampler.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=8 * num_loras * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 32000, + lora_config.lora_extra_vocab_size) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None)[:, :32000] + expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + 
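The embedding, sampler, and linear tests in this file all compare the LoRA-wrapped layer against the same reference: the base layer's output plus x @ lora_a @ lora_b * scaling for whichever LoRA the mapping assigns to each row. Below is a minimal standalone sketch of that reference, using the stacked-weight layout and the -1 "no LoRA" sentinel from tests/lora/test_lora.py; the helper name ref_apply_lora and the shape comments are illustrative rather than part of the patch, and scaling is assumed to be pre-folded into the B stack, as those tests do.

import torch


def ref_apply_lora(base_out: torch.Tensor,
                   x: torch.Tensor,
                   lora_a_stack: torch.Tensor,
                   lora_b_stack: torch.Tensor,
                   indices: torch.Tensor) -> torch.Tensor:
    """Per-token LoRA reference computation.

    base_out:     [num_tokens, out_features], output of the base layer
    x:            [num_tokens, in_features]
    lora_a_stack: [num_slots, 1, rank, in_features]
    lora_b_stack: [num_slots, 1, out_features, rank], scaling pre-folded
    indices:      [num_tokens], LoRA slot per token, -1 means "no LoRA"
    """
    out = base_out.clone()
    for i, slot in enumerate(indices.tolist()):
        if slot < 0:
            continue  # token has no active LoRA; keep the base output
        a = lora_a_stack[slot, 0]   # [rank, in_features]
        b = lora_b_stack[slot, 0]   # [out_features, rank]
        out[i] += b @ (a @ x[i])    # equals x[i] @ A @ B * scaling
    return out

This per-token gathered matrix-vector product is what the csrc/punica bgmv kernels perform in a single batched launch; the layer tests above build the same expectation one prompt at a time and concatenate the results before the allclose comparison.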
+@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("orientation", ["row", "column"]) +def test_linear_parallel(dist_init, num_loras, orientation) -> None: + + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_random_linear_parallel_layer(): + if orientation == "row": + linear = RowParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRARowParallelLinear(linear) + else: + linear = ColumnParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAColumnParallelLinear(linear) + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_parallel_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("repeats", [2, 3]) +def test_column_parallel_packed(dist_init, num_loras, repeats) -> None: + lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + max_loras = 8 + + def create_column_parallel_packed_layer(): + if repeats == 2: + linear = MergedColumnParallelLinear(4096, [4096] * repeats, + bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAMergedColumnParallelLinear2Slice(linear) + else: + linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = LoRAQKVParallelLinear(linear) + + @dataclass + class FakeConfig: + hidden_size = 4096 + num_key_value_heads = 32 + 
num_attention_heads = 32 + + lora_linear.create_lora_weights(max_loras, + lora_config, + model_config=FakeConfig()) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + + linear, lora_linear = create_column_parallel_packed_layer() + + lora_dict, sublora_dict = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + repeats=repeats, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * ( + i + 1 + )] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py new file mode 100644 index 0000000000000..756fc55246092 --- /dev/null +++ b/tests/lora/test_llama.py @@ -0,0 +1,141 @@ +import pytest +import ray +import torch + +import vllm +from vllm.lora.request import LoRARequest + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + + +def do_sample(llm, lora_path: str, lora_id: int): + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? 
[/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]" + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [1, 2, 4]) +def test_llama_lora(sql_lora_files, tp_size): + if torch.cuda.device_count() < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=tp_size, + worker_use_ray=True) + + expected_no_lora_output = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? 
[/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", + ] + expected_lora_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " + ] + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output + + print("removing lora") + + +def test_llama_tensor_parallel_equality(sql_lora_files): + if torch.cuda.device_count() < 4: + pytest.skip(f"Not enough GPUs for tensor parallelism {4}") + + llm_tp1 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=1, + worker_use_ray=True) + output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) + + del llm_tp1 + ray.shutdown() + + llm_tp2 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=2, + worker_use_ray=True) + output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) + + del llm_tp2 + ray.shutdown() + + assert output_tp1 == output_tp2 + + llm_tp4 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + tensor_parallel_size=4, + worker_use_ray=True) + output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) + + del llm_tp4 + 
ray.shutdown() + + assert output_tp1 == output_tp4 + + +def test_llama_lora_warmup(sql_lora_files): + """Test that the LLM initialization works with a warmup LORA path and is more conservative""" + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_lora(): + llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) + num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_lora_warmup + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_no_lora(): + llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) + num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_no_lora_warmup + + num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) + num_gpu_blocks_no_lora_warmup = ray.get( + get_num_gpu_blocks_no_lora.remote()) + assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( + "The warmup with lora should be more" + " conservative than without lora, therefore the number of memory blocks for the KV cache should be " + "less when using lora than when not using lora") diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py new file mode 100644 index 0000000000000..b86f7a480e749 --- /dev/null +++ b/tests/lora/test_lora.py @@ -0,0 +1,224 @@ +import pytest +import torch + +from vllm.lora.layers import _apply_lora, _apply_lora_packed_2slice, _apply_lora_packed_3slice + +from .utils import DummyLoRAManager + +TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] +QKV_TENSOR_SIZES = [ + (8192, 1024, 1024), + (8192 // 8, 1024 // 8, 1024 // 8), + (4096, 4096, 4096), + (4096 // 2, 4096 // 2, 4096 // 2), +] +BATCH_SIZES = [8, 32, 256] +RANKS = [8] +DTYPES = [torch.float16] +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora(m, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name, weight, rank=rank) + lora = manager.get_module_lora(module_name) + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = input @ lora.lora_a @ lora.lora_b * lora.scaling + + lora_a_stack = torch.zeros(8, + 1, + lora.lora_a.shape[1], + lora.lora_a.shape[0], + device="cuda", + dtype=dtype) + lora_b_stack = torch.zeros(8, + 1, + lora.lora_b.shape[1], + lora.lora_b.shape[0], + device="cuda", + dtype=dtype) + for i in range(lora_a_stack.shape[0]): + lora_a_stack[i][0] = lora.lora_a.T + lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora( + input, lora_a_stack, lora_b_stack, + torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"), + output) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.full((len(input), ), -1, device="cuda"), output) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_2slice(m, n, k, rank, 
dtype) -> None: + if m % 2 != 0: + pytest.skip("m must be divisible by 2") + if m // 2 not in TENSOR_SIZES: + pytest.skip("m//2 must be in TENSOR_SIZES") + + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m // 2, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "1", weight, rank=rank) + lora_1 = manager.get_module_lora(module_name + "1") + manager.init_random_lora(module_name + "2", weight, rank=rank) + lora_2 = manager.get_module_lora(module_name + "2") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, + input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_a.shape[1], + lora_1.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_b.shape[1], + lora_1.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_1.lora_a.T + lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T + lora_a_stacks[1][i][0] = lora_2.lora_a.T + lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora_packed_2slice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, m // 2) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_2slice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, m // 2) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype) + weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "q", weight_q, rank=rank) + lora_q = manager.get_module_lora(module_name + "q") + manager.init_random_lora(module_name + "k", weight_kv, rank=rank) + lora_k = manager.get_module_lora(module_name + "k") + manager.init_random_lora(module_name + "v", weight_kv, rank=rank) + lora_v = manager.get_module_lora(module_name + "v") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, + input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, + input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_a.shape[1], + lora_q.lora_a.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_a.shape[1], + lora_k.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_b.shape[1], + lora_q.lora_b.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_b.shape[1], + lora_k.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = 
lora_q.lora_a.T + lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T + lora_a_stacks[1][i][0] = lora_k.lora_a.T + lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T + lora_a_stacks[2][i][0] = lora_v.lora_a.T + lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T + + output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) + _apply_lora_packed_3slice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, (qkv[0], qkv[1])) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_3slice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, (qkv[0], qkv[1])) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py new file mode 100644 index 0000000000000..de7b245ad4e79 --- /dev/null +++ b/tests/lora/test_lora_manager.py @@ -0,0 +1,473 @@ +import os +from typing import List + +import pytest +import torch +from safetensors.torch import load_file +from torch import nn + +from vllm.config import LoRAConfig +from vllm.lora.layers import (LoRAColumnParallelLinear, LoRARowParallelLinear, + LoRAMergedColumnParallelLinear2Slice) +from vllm.lora.lora import LoRA, PackedLoRA +from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, LoRAMapping) +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, + WorkerLoRAManager) +from vllm.model_executor.layers.linear import RowParallelLinear + + +def test_from_lora_tensors(sql_lora_files): + tensors = load_file( + os.path.join(sql_lora_files, "adapter_model.safetensors")) + new_embeddings = load_file( + os.path.join(sql_lora_files, "new_embeddings.safetensors")) + lora_model = LoRAModel.from_lora_tensors(1, + 8, + 16, + tensors, + "cuda", + embeddings=new_embeddings) + for module_name, lora in lora_model.loras.items(): + assert lora.module_name == module_name + assert lora.rank == 8 + assert lora.lora_alpha == 16 + assert lora.lora_a is not None + assert lora.lora_b is not None + assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] + ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + assert lora.lora_a.shape[1] == 8 + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), None) + if embeddings_module: + assert torch.equal( + lora.embeddings_tensor, + new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( + device=lora.embeddings_tensor.device)) + else: + assert lora.embeddings_tensor is None + + +def create_lora(lora_id: int, model: nn.Module, + sub_modules: List[str]) -> LoRAModel: + loras = {} + for name in sub_modules: + w = model.get_submodule(name).weight + loras[name] = LoRA( + name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0]], device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def create_packed_lora( + lora_id: int, + model: nn.Module, + module_name, + replaced_module_names, + empty_replaced_module_name=None, +) -> LoRAModel: + w = model.get_submodule(module_name).weight + loras = {} + for replaced_module_name in replaced_module_names: + if replaced_module_name == empty_replaced_module_name: + continue + loras[replaced_module_name] = LoRA( + replaced_module_name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0] // 
len(replaced_module_names)], + device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def test_replace_submodules(dist_init, dummy_model): + model = dummy_model + manager = LoRAModelManager(model, + 1, + 1, + 1, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=8, + max_loras=8), + lora_target_modules=["dense1", "layer1.dense2"]) + model = manager.model + + assert isinstance(model.get_submodule("dense1"), LoRAColumnParallelLinear) + assert isinstance(model.get_submodule("layer1.dense1"), + LoRAColumnParallelLinear) + assert isinstance(model.get_submodule("dense2"), RowParallelLinear) + assert isinstance(model.get_submodule("layer1.dense2"), + LoRARowParallelLinear) + + +def test_lora_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_id_to_index) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert manager.add_lora(model_lora3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + with pytest.raises(ValueError): + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_id_to_index[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] is None + assert manager.add_lora(model_lora2) + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] is None + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + + +def test_lora_lru_cache_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_id_to_index) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert 
manager.add_lora(model_lora3) + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_id_to_index[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 1 + assert manager.add_lora(model_lora2) + assert manager.deactivate_lora(3) + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 1 + assert manager.activate_lora(2) + assert manager.lora_id_to_index[0] == 2 + assert manager.lora_id_to_index[1] == 1 + assert manager.activate_lora(3) + assert manager.lora_id_to_index[0] == 2 + assert manager.lora_id_to_index[1] == 3 + + +def test_lru_lora_model_manager(dist_init, dummy_model): + # This tests just the LRU cache functionality, everything else is + # tested in test_lora_model_manager + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["dense1", "dense2", "lm_head"]) + + assert all(x is None for x in manager.lora_id_to_index) + + # Add up to capacity + assert manager.add_lora(model_lora1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(1) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {1, 2} + assert manager.lora_id_to_index[0] == 1 + assert manager.lora_id_to_index[1] == 2 + + # Add over capacity + assert manager.add_lora(model_lora3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(3) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 4 + + # Add 3 again to move it to the top and then add 2 + # should return false since it's in already + assert not manager.add_lora(model_lora3) + assert not manager.activate_lora(3) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {3, 2} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 2 + + # Remove manually + assert manager.remove_lora(3) + assert not manager.remove_lora(3) + + assert set(manager.list_loras()) == {2} + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 2 + + assert manager.add_lora(model_lora3) + assert manager.activate_lora(3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_id_to_index[0] == 3 + assert manager.lora_id_to_index[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == {4} + assert manager.lora_id_to_index[0] is None + assert manager.lora_id_to_index[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_id_to_index) + + assert not 
manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_id_to_index) + + +def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = LRUCacheWorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 3, 4} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 3 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 6, 7, 8} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 7 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 8 + assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 6 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.apply_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + # Should remove every LoRA not specified in the request. 
+ lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = WorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 3, 4} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 3 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 4 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 5} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + + worker_lora_manager.apply_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] is None + assert worker_lora_manager._lora_manager.lora_id_to_index[2] is None + + worker_lora_manager.apply_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {6, 7, 8} + assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 8 + assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 6 + assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 7 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.apply_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_packed_loras(dist_init, dummy_model_gate_up): + model = dummy_model_gate_up + model_lora = create_packed_lora( + 1, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"]) + model_lora1 = create_packed_lora( + 2, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"], + empty_replaced_module_name="gate_proj", + ) + + manager = LoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["gate_up_proj"]) + model = manager.model + + assert isinstance(model.get_submodule("gate_up_proj"), + LoRAMergedColumnParallelLinear2Slice) + assert manager.add_lora(model_lora) + assert manager.add_lora(model_lora1) + + packed_lora = model_lora.get_lora("gate_up_proj") + assert packed_lora and isinstance(packed_lora, PackedLoRA) + + assert torch.allclose(packed_lora.lora_a[0], + 
model_lora.get_lora("gate_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[0], + model_lora.get_lora("gate_proj").lora_b) + assert torch.allclose(packed_lora.lora_a[1], + model_lora.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[1], + model_lora.get_lora("up_proj").lora_b) + + packed_lora1 = model_lora1.get_lora("gate_up_proj") + assert packed_lora1 and isinstance(packed_lora1, PackedLoRA) + + assert packed_lora1.lora_a[0] is None + assert packed_lora1.lora_b[0] is None + assert torch.allclose(packed_lora1.lora_a[1], + model_lora1.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora1.lora_b[1], + model_lora1.get_lora("up_proj").lora_b) diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py new file mode 100644 index 0000000000000..26a7d47933309 --- /dev/null +++ b/tests/lora/test_punica.py @@ -0,0 +1,196 @@ +# Based on code from https://github.com/punica-ai/punica + +import pytest +import torch + +import vllm.lora.punica as punica + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), + torch.float32: (None, None), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +def _lora_ref_impl( + y_final: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, +): + y_stage_1 = torch.empty( + (x.size(0), wa_T_all.size(-2)), + dtype=torch.float32, + device=x.device, + ) + bs = x.shape[0] + s = torch.tensor(scale, dtype=torch.float32, device=x.device) + for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): + xi = x[i].unsqueeze(0).to(torch.float32) + wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + + tmp = xi @ wa + y_stage_1[i] = tmp.squeeze(0) + y_final[i] += (tmp @ wb).squeeze(0) * s + return y_final, y_stage_1 + + +H1 = H2 = [ + 128, + 256, + 512, + 1024, + 1280, + 2048, + 2560, + 2752, + 3072, + 3456, + 3584, + 4096, + 5120, + 5504, + 6912, + 7168, + 8192, + 9216, + 10240, + 11008, + 13824, + 14336, + 32000, + 32256, +] +SEED = [0xabcdabcd987] + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness(dtype_str, h1, h2, seed): + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all = torch.randn(num_loras, + num_layers, + h2, + r, + dtype=dtype, + device=device) + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + + y_ref = y.clone() + _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) + + y_our = y.clone() + punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx, + scale) + + assert_close(y_ref, y_our) + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness_slice(dtype_str, h1, h2, seed): + if h2 % 3 != 0 or h2 // 3 not 
in H1: + pytest.skip("h2 must be divisible by 3 and in supported shapes") + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all_0 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_1 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_2 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all_0 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_1 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_2 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + s = h2 // 3 + + y_ref = y.clone() + _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale) + + y_our = y.clone() + punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale, 0, s) + punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale, s, s) + punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale, s * 2, s) + + assert_close(y_ref[:, :s], y_our[:, :s]) + assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2]) + assert_close(y_ref[:, s * 2:], y_our[:, s * 2:]) diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py new file mode 100644 index 0000000000000..af0fc41f3fa45 --- /dev/null +++ b/tests/lora/test_tokenizer.py @@ -0,0 +1,69 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import MultiLoRATokenizer, get_lora_tokenizer + + +@pytest.mark.asyncio +async def test_transformers_tokenizer(): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = MultiLoRATokenizer( + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=None) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=None) + assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + +@pytest.mark.asyncio +async def test_transformers_tokenizer_lora(sql_lora_files): + reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) + tokenizer = MultiLoRATokenizer( + tokenizer_id="gpt2", + enable_lora=True, + max_num_seqs=1, + max_input_length=None, + ) + lora_request = LoRARequest("1", 1, sql_lora_files) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=lora_request) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=lora_request) + 
assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + assert isinstance(tokenizer.get_lora_tokenizer(lora_request), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + lora_request) != tokenizer.get_lora_tokenizer(None) + assert tokenizer.get_lora_tokenizer( + lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request) + + +def test_get_lora_tokenizer(sql_lora_files, tmpdir): + lora_request = None + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer + + lora_request = LoRARequest("1", 1, sql_lora_files) + tokenizer = get_lora_tokenizer(lora_request) + assert tokenizer.get_added_vocab() + + lora_request = LoRARequest("1", 1, str(tmpdir)) + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py new file mode 100644 index 0000000000000..a874a72d919fa --- /dev/null +++ b/tests/lora/test_utils.py @@ -0,0 +1,172 @@ +from collections import OrderedDict + +from torch import nn + +from vllm.lora.utils import (LRUCache, parse_fine_tuned_lora_name, + replace_submodule) + + +def test_parse_fine_tuned_lora_name(): + fixture = { + ("base_model.model.lm_head.lora_A.weight", "lm_head", True), + ("base_model.model.lm_head.lora_B.weight", "lm_head", False), + ( + "base_model.model.model.embed_tokens.lora_embedding_A", + "model.embed_tokens", + True, + ), + ( + "base_model.model.model.embed_tokens.lora_embedding_B", + "model.embed_tokens", + False, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", + "model.layers.9.mlp.down_proj", + True, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", + "model.layers.9.mlp.down_proj", + False, + ), + } + for name, module_name, is_lora_a in fixture: + assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) + + +def test_replace_submodule(): + model = nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(764, 100)), + ("act1", nn.ReLU()), + ("dense2", nn.Linear(100, 50)), + ( + "seq1", + nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(100, 10)), + ("dense2", nn.Linear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", nn.Linear(50, 10)), + ("outact", nn.Sigmoid()), + ])) + + sigmoid = nn.Sigmoid() + + replace_submodule(model, "act1", sigmoid) + assert dict(model.named_modules())["act1"] == sigmoid + + dense2 = nn.Linear(1, 5) + replace_submodule(model, "seq1.dense2", dense2) + assert dict(model.named_modules())["seq1.dense2"] == dense2 + + +class TestLRUCache(LRUCache): + + def _on_remove(self, key, value): + if not hasattr(self, "_remove_counter"): + self._remove_counter = 0 + self._remove_counter += 1 + + +def test_lru_cache(): + cache = TestLRUCache(3) + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(2, 2) + assert len(cache) == 2 + + cache.put(3, 3) + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache.put(4, 4) + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache.get(2) == 2 + + cache.put(5, 5) + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + assert cache.pop(5) == 5 + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + 
cache.get(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.put(6, 6) + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + cache.remove_oldest() + assert len(cache) == 2 + assert set(cache.cache) == {2, 6} + assert cache._remove_counter == 4 + + cache.clear() + assert len(cache) == 0 + assert cache._remove_counter == 6 + + cache._remove_counter = 0 + + cache[1] = 1 + assert len(cache) == 1 + + cache[1] = 1 + assert len(cache) == 1 + + cache[2] = 2 + assert len(cache) == 2 + + cache[3] = 3 + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache[4] = 4 + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache[2] == 2 + + cache[5] = 5 + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + del cache[5] + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache[6] = 6 + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py new file mode 100644 index 0000000000000..8c11f6c472ff7 --- /dev/null +++ b/tests/lora/test_worker.py @@ -0,0 +1,56 @@ +import os +import random +import tempfile +from unittest.mock import patch + +from vllm.lora.models import LoRAMapping +from vllm.lora.utils import LoRARequest +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig +from vllm.worker.worker import Worker + + +@patch.dict(os.environ, {"RANK": "0"}) +def test_worker_apply_lora(sql_lora_files): + worker = Worker( + model_config=ModelConfig("meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None), + parallel_config=ParallelConfig(1, 1, False), + scheduler_config=SchedulerConfig(32, 32, 32, 256), + lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, + max_loras=32), + distributed_init_method=f"file://{tempfile.mkstemp()[1]}", + ) + worker.init_model() + worker.load_model() + + worker.apply_loras([], LoRAMapping([], [])) + assert worker.list_loras() == set() + + n_loras = 32 + lora_requests = [ + LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) + ] + + worker.apply_loras(lora_requests, LoRAMapping([], [])) + assert worker.list_loras() == { + lora_request.lora_int_id + for lora_request in lora_requests + } + + for i in range(32): + random.seed(i) + iter_lora_requests = random.choices(lora_requests, + k=random.randint(1, n_loras)) + random.shuffle(iter_lora_requests) + iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] + worker.apply_loras(iter_lora_requests, LoRAMapping([], [])) + assert worker.list_loras().issuperset( + {lora_request.lora_int_id + for lora_request in iter_lora_requests}) diff --git a/tests/lora/utils.py b/tests/lora/utils.py new file mode 100644 index 0000000000000..072a0d957758b --- /dev/null +++ b/tests/lora/utils.py @@ -0,0 +1,88 @@ +from typing import List, Optional + +import torch + +from vllm.lora.lora import LoRA + + +class DummyLoRAManager: + + def __init__(self): + super().__init__() + self._loras = {} + + def set_module_lora(self, module_name: str, 
lora: LoRA): + self._loras[module_name] = lora + + def get_module_lora(self, module_name: str) -> Optional[LoRA]: + return self._loras.get(module_name, None) + + def init_random_lora(self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0): + lora = LoRA( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([weight.shape[1], rank], + dtype=weight.dtype, + device="cuda"), + lora_b=torch.rand([rank, weight.shape[0]], + dtype=weight.dtype, + device="cuda"), + ) + if generate_embeddings_tensor: + lora.embeddings_tensor = torch.rand(5, + generate_embeddings_tensor, + dtype=weight.dtype, + device="cuda") + self.set_module_lora(module_name, lora) + + return lora + + def init_lora(self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None): + lora = LoRA( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([input_dim, rank], device="cuda"), + lora_b=torch.rand([rank, output_dim], device="cuda"), + embeddings_tensor=embeddings_tensor, + ) + self.set_module_lora(module_name, lora) + return lora + + def reset_lora(self): + self._loras = {} + + def init_packed_lora( + self, + module_name: str, + input_dim: int, + output_dims: List[int], + noop_lora_index: List[int] = None, + rank=8, + ): + base_loras = [] + noop_lora_index = set(noop_lora_index or []) + + for i, out_dim in enumerate(output_dims): + base_lora = self.init_lora( + module_name + "_000_" + str(i), + input_dim, + out_dim, + rank=rank, + noop=i in noop_lora_index, + ) + base_loras.append(base_lora) + packed_lora = LoRA.pack(base_loras) + self.set_module_lora(module_name, packed_lora) + return packed_lora diff --git a/vllm/config.py b/vllm/config.py index 1adf830ffcc12..eef6e53be2855 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,5 @@ from typing import Optional, Union +from dataclasses import dataclass import os import torch @@ -350,6 +351,36 @@ def _verify_args(self) -> None: f"({self.max_num_seqs}).") +@dataclass +class LoRAConfig: + max_lora_rank: int + max_cpu_loras: Optional[int] = None + lora_dtype: Optional[torch.dtype] = None + lora_extra_vocab_size: int = 256 + max_loras: Optional[int] = None + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + + def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): + if scheduler_config.max_num_batched_tokens > 65528: + raise ValueError( + "Due to limitations of the custom LoRA CUDA kernel, " + "max_num_batched_tokens must be <= 65528 when " + "LoRA is enabled.") + + self.max_loras = scheduler_config.max_num_seqs + if self.max_cpu_loras is None: + self.max_cpu_loras = scheduler_config.max_num_seqs + elif self.max_cpu_loras < scheduler_config.max_num_seqs: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_num_seqs ({scheduler_config.max_num_seqs})") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index ca28bbdc2fb95..f8fb4c6ea1518 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,10 +1,11 @@ import enum import time -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union, Set from vllm.config import CacheConfig, SchedulerConfig from 
vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory +from vllm.lora.request import LoRARequest from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -36,6 +37,7 @@ def __init__( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], + lora_enabled: bool = False, ) -> None: self.scheduled_seq_groups = scheduled_seq_groups self.prompt_run = prompt_run @@ -47,11 +49,23 @@ def __init__( assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups + if lora_enabled: + self.num_loras = len(set(self.lora_requests)) + self._sort_by_lora_ids() + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) + def _sort_by_lora_ids(self) -> bool: + self.scheduled_seq_groups.sort(key=lambda g: ( + g.lora_request.lora_int_id if g.lora_request else 0, g.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return {g.lora_request for g in self.scheduled_seq_groups} + class Scheduler: @@ -59,9 +73,11 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + lora_enabled: bool = False, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + self.lora_enabled = lora_enabled self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -202,6 +218,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, + lora_enabled=self.lora_enabled, ) return scheduler_outputs @@ -274,6 +291,7 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=[], + lora_enabled=self.lora_enabled, ) return scheduler_outputs @@ -299,6 +317,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + lora_request=seq_group.lora_request, ) seq_group_metadata_list.append(seq_group_metadata) return seq_group_metadata_list, scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 746b0e64ece7b..4d1233c473980 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) @dataclass @@ -33,6 +33,11 @@ class EngineArgs: revision: Optional[str] = None tokenizer_revision: Optional[str] = None quantization: Optional[str] = None + enable_lora: bool = False + max_lora_rank: int = 8 + lora_extra_vocab_size: int = 256 + lora_dtype = 'bfloat16' + lora_max_cpu_loras: int = -1 def __post_init__(self): if self.tokenizer is None: @@ -182,6 +187,30 @@ def add_cli_args( choices=['awq', 'squeezellm', None], default=None, help='Method used to quantize the weights') + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='enable lora adapters') + parser.add_argument('--max-lora-rank', + type=int, + default=16, + help='max LoRA rank') + parser.add_argument('--lora-extra-vocab-size', + type=int, + default=256, + help='LoRA extra 
vocab size') + parser.add_argument('--lora-dtype', + type=str, + default=EngineArgs.dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help='data type for lora') + parser.add_argument( + '--lora-max-cpu-loras', + type=int, + default=-1, + help=('Maximum number of loras to store in CPU memory. ' + 'Must be >= than max_num_seqs. ' + 'Defaults to max_num_seqs.')) return parser @classmethod @@ -194,7 +223,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, + Optional[LoRAConfig]]: model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, @@ -212,7 +242,13 @@ def create_engine_configs( self.max_num_seqs, model_config.max_model_len, self.max_paddings) - return model_config, cache_config, parallel_config, scheduler_config + lora_config = LoRAConfig( + max_lora_rank=self.max_lora_rank, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.lora_max_cpu_loras if self.lora_max_cpu_loras > + 0 else None) if self.enable_lora else None + return model_config, cache_config, parallel_config, scheduler_config, lora_config @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7dcd2eb632c4c..53bc7080b3273 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,6 +4,7 @@ from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union) +from vllm.lora.request import LoRARequest from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine @@ -198,6 +199,50 @@ async def step_async(self) -> List[RequestOutput]: return self._process_model_outputs(output, scheduler_outputs) + ignored + async def encode_request_async( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = await self.tokenizer.encode_async( + request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + + async def add_request_async( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + ) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + return self.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + ) + async def _run_workers_async( self, method: str, @@ -318,7 +363,7 @@ async def engine_step(self) -> bool: if self.engine_use_ray: await self.engine.add_request.remote(**new_request) else: - self.engine.add_request(**new_request) + await self.engine.add_request_async(**new_request) if finished_requests: await 
self._engine_abort(finished_requests) @@ -357,6 +402,7 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> AsyncStream: if self.log_requests: shortened_prompt = prompt @@ -370,7 +416,8 @@ async def add_request( logger.info(f"Received request {request_id}: " f"prompt: {shortened_prompt!r}, " f"sampling params: {sampling_params}, " - f"prompt token ids: {shortened_token_ids}.") + f"prompt token ids: {shortened_token_ids}, " + f"lora_request: {lora_request}.") if not self.is_running: if self.start_engine_loop: @@ -382,12 +429,22 @@ async def add_request( "error that caused the background loop to stop " "(AsyncEngineDeadError).") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.engine.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + stream = self._request_tracker.add_request( request_id, prompt=prompt, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + arrival_time=arrival_time, + lora_request=lora_request, + ) return stream @@ -396,7 +453,8 @@ async def generate( prompt: Optional[str], sampling_params: SamplingParams, request_id: str, - prompt_token_ids: Optional[List[int]] = None) -> RequestOutput: + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None) -> RequestOutput: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -410,6 +468,7 @@ async def generate( request_id: The unique id of the request. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. 
Yields: The output `RequestOutput` objects from the LLMEngine for the @@ -420,11 +479,14 @@ async def generate( arrival_time = time.monotonic() try: - stream = await self.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time) + stream = await self.add_request( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + ) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e33d8aa2a2131..c6e74b1d26586 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -3,8 +3,9 @@ from functools import partial from typing import TYPE_CHECKING, Any, Iterable, List, Optional, Tuple, Union +from vllm.lora.request import LoRARequest from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.ray_utils import RayWorker, initialize_cluster, ray @@ -15,7 +16,7 @@ SequenceGroupMetadata, SequenceGroupOutputs, SequenceOutputs, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) + MultiLoRATokenizer) from vllm.utils import Counter if ray: @@ -65,6 +66,7 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], distributed_init_method: str, placement_group: Optional["PlacementGroup"], log_stats: bool, @@ -90,17 +92,13 @@ def __init__( self.cache_config = cache_config assert self.cache_config.sliding_window == getattr( self.model_config.hf_config, "sliding_window", None) + self.lora_config = lora_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.log_stats = log_stats self._verify_args() - self.tokenizer = get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - tokenizer_revision=model_config.tokenizer_revision, - revision=model_config.revision) + self._init_tokenizer() self.seq_counter = Counter() # Create the parallel GPU workers. 
@@ -137,6 +135,7 @@ def _init_workers(self, distributed_init_method: str): self.scheduler_config, 0, distributed_init_method, + lora_config=self.lora_config, ) self.workers.append(worker) self._run_workers( @@ -150,6 +149,18 @@ def _init_workers(self, distributed_init_method: str): max_parallel_loading_workers, ) + def _init_tokenizer(self, **kwargs): + init_kwargs = dict( + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + init_kwargs.update(kwargs) + self.tokenizer: MultiLoRATokenizer = MultiLoRATokenizer( + self.model_config.tokenizer, **init_kwargs) + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Lazy import the Worker to avoid importing torch.cuda/xformers @@ -183,6 +194,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", scheduler_config, None, None, + lora_config=self.lora_config, )) self._run_workers( "init_model", @@ -198,6 +210,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) def _init_cache(self) -> None: """Profiles the memory usage and initializes the KV cache.""" @@ -246,6 +262,20 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": log_stats=not engine_args.disable_log_stats) return engine + def encode_request( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = self.tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + def add_request( self, request_id: str, @@ -253,6 +283,7 @@ def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: """Add a request to the engine's request pool. @@ -270,20 +301,26 @@ def add_request( arrival_time: The arrival time of the request. If None, we use the current monotonic time. """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") if arrival_time is None: arrival_time = time.monotonic() - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(prompt) + prompt_token_ids = self.encode_request( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, + lora_request) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + arrival_time, lora_request) # Add the sequence group to the scheduler. 
self.scheduler.add_seq_group(seq_group) @@ -648,7 +685,7 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( - self.tokenizer, + self.tokenizer.get_lora_tokenizer(seq.lora_request), all_input_ids=seq.get_token_ids(), prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, @@ -689,11 +726,29 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == self.tokenizer.eos_token_id): + if ((not sampling_params.ignore_eos) and seq.get_last_token_id() + == self.tokenizer.get_lora_tokenizer( + seq.lora_request).eos_token_id): seq.status = SequenceStatus.FINISHED_STOPPED return + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> List[int]: + return self._run_workers("list_loras") + def _run_workers_in_batch( self, workers, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b05ba71c6d352..9061909d72c33 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,6 +3,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm.lora.request import LoRARequest from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.outputs import RequestOutput @@ -109,6 +110,7 @@ def generate( sampling_params: Optional[SamplingParams] = None, prompt_token_ids: Optional[List[List[int]]] = None, use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -123,6 +125,7 @@ def generate( prompt_token_ids: A list of token IDs for the prompts. If None, we use the tokenizer to convert the prompts to token IDs. use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. Returns: A list of `RequestOutput` objects containing the generated @@ -149,7 +152,10 @@ def generate( prompt = prompts[i] if prompts is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request(prompt, sampling_params, token_ids) + self._add_request(prompt, + sampling_params, + token_ids, + lora_request=lora_request) return self._run_engine(use_tqdm) def _add_request( @@ -157,10 +163,14 @@ def _add_request( prompt: Optional[str], sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], + lora_request: Optional[LoRARequest] = None, ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids) + self.llm_engine.add_request(request_id, + prompt, + sampling_params, + prompt_token_ids, + lora_request=lora_request) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
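
As a rough usage sketch of the new lora_request path added above (the model name, adapter path, and prompt are placeholders rather than files shipped with this patch; the calls mirror those exercised in tests/lora/):

import vllm
from vllm import SamplingParams
from vllm.lora.request import LoRARequest

# enable_lora=True wires a LoRAConfig through EngineArgs into the engine and workers.
llm = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=True, max_num_seqs=16)

# A LoRARequest names the adapter, gives it a positive integer id, and points at
# the local directory holding the adapter weights.
sql_lora = LoRARequest("sql_adapter", 1, "/path/to/sql_lora_adapter")

# generate() forwards lora_request to LLMEngine.add_request(), which tags the
# sequence group so the scheduler and the worker LoRA managers activate the adapter.
outputs = llm.generate(["Write a SQL query that lists all users."],
                       SamplingParams(max_tokens=64),
                       lora_request=sql_lora)
print(outputs[0].outputs[0].text)
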
diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000000000..6ba8b0585847d --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,1002 @@ +# pylint: disable=unused-argument +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear, + MergedColumnParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim + +if TYPE_CHECKING: + pass + + +def _apply_lora( + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + indices: torch.Tensor, + output: torch.Tensor, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: (num_loras, lora_rank, hidden_dim) + lora_b_stacked: (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + return output.view_as(org_output) + + +def _apply_lora_packed_2slice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_dim: int, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of 2 sublayers + (slices) packed together (eg. gate_proj + up_proj -> + gate_up_proj). + + Both slices must have the same size (output_dim), meaning the output + tensor will have size output_dim*2. 
+ + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 2 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 2 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim*2) + output_dim: scalar + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, + 1.0, 0, output_dim) + add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, + 1.0, output_dim, output_dim) + return output.view_as(org_output) + + +def _apply_lora_packed_3slice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, int], +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of 3 sublayers + (slices) packed together (attention projection). The + first slice (Q) may have different size from the two subsequent + slices (K, V). + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: 2 element tuple of (q_slice_size, kv_slice_size) + """ + org_output = output + if x.ndim == 3: + x = x.view(x.shape[0] * x.shape[1], -1) + if output.ndim == 3: + output = output.view(output.shape[0] * output.shape[1], -1) + add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, + 1.0, 0, output_slices[0]) + add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, + 1.0, output_slices[0], output_slices[1]) + add_lora_slice(output, x, lora_a_stacked[2], lora_b_stacked[2], indices, 0, + 1.0, output_slices[0] + output_slices[1], output_slices[1]) + return output.view_as(org_output) + + +@dataclass +class LoRAMapping: + index_mapping: Tuple[int, ...] + prompt_mapping: Tuple[int, ...] + + def __eq__(self, __value: object) -> bool: + return (isinstance(__value, self.__class__) + and self.prompt_mapping == __value.prompt_mapping + and self.index_mapping == __value.index_mapping) + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + + +class LoRALayer(nn.Module): + + def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, + model_config: PretrainedConfig) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + """Sets the mapping indices.""" + ... 
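
For clarity, a pure-PyTorch reference of the semantics that _apply_lora delegates to the Punica bgmv kernels, using the shapes documented in its docstring above. This explicit loop is illustration only (and far slower than the batched kernel); the function name is not part of the patch.

import torch

def apply_lora_reference(x: torch.Tensor,
                         lora_a_stacked: torch.Tensor,
                         lora_b_stacked: torch.Tensor,
                         indices: torch.Tensor,
                         output: torch.Tensor) -> torch.Tensor:
    # x: (batch, hidden), lora_a_stacked: (num_loras, rank, hidden),
    # lora_b_stacked: (num_loras, out_dim, rank), indices: (batch,).
    for i, idx in enumerate(indices.tolist()):
        if idx < 0:
            continue  # -1 means this row gets no LoRA.
        a = lora_a_stacked[idx]  # (rank, hidden)
        b = lora_b_stacked[idx]  # (out_dim, rank)
        # Scaling is folded into lora_b ahead of time, so scale == 1 here.
        output[i] += (x[i] @ a.t()) @ b.t()
    return output
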
+ + +class LoRAVocabParallelEmbedding(LoRALayer): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + lora_vocab_start_idx = self.base_layer.org_vocab_size + weights_idx = None + if self.base_layer.vocab_end_index > lora_vocab_start_idx: + # We can start adding lora weights + weights_idx = max( + lora_vocab_start_idx - self.base_layer.vocab_start_index, 0) + self.embeddings_slice = (self.base_layer.vocab_start_index - + self.base_layer.org_vocab_size + + weights_idx, + self.base_layer.vocab_end_index - + self.base_layer.org_vocab_size) + self.embeddings_weights = self.base_layer.weight.data[weights_idx:] + self.embeddings_weights.fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.embeddings_indices = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. 
+ shape[1]].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + self.embeddings_weights.copy_( + self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2]) + [self.embeddings_slice[0]:self.embeddings_slice[1]]) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.embeddings_indices = embeddings_indices + self.indices_len = indices_len + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], -1) + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + return full_output.view_as(full_output_org) + + +class LoRAColumnParallelLinear(LoRALayer): + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.output_dim = self.lora_b_stacked.shape[1] + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + 
def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply_weights(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @property + def linear_weights(self): + return self.base_layer.linear_weights + + +class LoRAMergedColumnParallelLinear2Slice(LoRAColumnParallelLinear): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + n_slices = 2 + if not (len(self.base_layer.output_sizes) == n_slices + and self.base_layer.output_sizes[0] + == self.base_layer.output_sizes[1]): + raise ValueError( + "LoRAColumnParallelLinear2Slice requires 2 slices with " + "the same size.") + self.tp_size = get_tensor_model_parallel_world_size() + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0] // 2, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + + self.indices: Optional[torch.Tensor] = None + self.output_dim = self.lora_b_stacked[0].shape[2] + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_b_stacked[1][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[0][:, + start_idx:end_idx], lora_b[1][:, + start_idx:end_idx] + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_2slice( + 
x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_dim, + ) + return output + + +class LoRAQKVParallelLinear(LoRAColumnParallelLinear): + """ColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + + # q, k, v + self.lora_a_stacked = (torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + )) + self.lora_b_stacked = (torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + )) + + self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size) + self.packed_indices: Optional[torch.Tensor] = None + self.standard_indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[1][index] = 0 + self.lora_a_stacked[2][index] = 0 + self.lora_b_stacked[2][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + 
(self.kv_shard_id + 1)] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) + else: + if lora_b[0] is not None: + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_b[1] is not None: + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + if lora_b[2] is not None: + self.lora_b_stacked[2][ + index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( + lora_b[2].T, non_blocking=True) + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + if lora_a[2] is not None: + self.lora_a_stacked[2][ + index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( + lora_a[2].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_3slice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_slices, + ) + return output + + +class LoRARowParallelLinear(LoRALayer): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + if self.base_layer.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.base_layer.weight.shape[1] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + def 
forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return self.base_layer.weight + + +class LoRASampler(LoRALayer): + + def __init__( + self, + base_layer: Sampler, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + self.indices = None + self.indices_padded = None + self.indices_len = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. 
+ shape[1], ] = embeddings_tensor + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = sampler_indices + self.indices_padded = sampler_indices_padded + self.indices_len = indices_len + + def _get_logits( + self, + hidden_states: torch.Tensor, + embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + + logits[:, self.base_layer.org_vocab_size:] = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, + self.indices_padded[:self.indices_len[2]]).nan_to_num_( + nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + _apply_lora( + hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[1]], + logits, + ) + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> LoRALayer: + supported_layer_types = { + VocabParallelEmbedding: LoRAVocabParallelEmbedding, + ColumnParallelLinear: LoRAColumnParallelLinear, + QKVParallelLinear: LoRAQKVParallelLinear, + MergedColumnParallelLinear: LoRAMergedColumnParallelLinear2Slice, + RowParallelLinear: LoRARowParallelLinear, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_sampler( + layer: Sampler, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LoRASampler: + ret = LoRASampler(layer, lm_head.embedding_dim, lm_head.weight.dtype, + lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000000000..042a98597ab26 --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,120 @@ +from typing import List, Optional + +import torch + + +class LoRA: + """A LoRA that is composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling 
= scaling + + @classmethod + def pack(cls, loras: List["LoRA"]) -> "PackedLoRA": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. + """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = PackedLoRA( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[1 if lora is not None else None for lora in loras]) + return obj + + def optimize(self) -> "LoRA": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + +class PackedLoRA(LoRA): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[int], + lora_a: List[torch.Tensor], + lora_b: List[torch.Tensor], + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + scaling=scaling, + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ + lora_alpha / self.rank for lora_alpha in self.lora_alphas + ] + + def optimize(self) -> "PackedLoRA": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: + continue + self.lora_b[i] *= self.scaling[i] + self.scaling[i] = 1 + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000000000..913234475b182 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,666 @@ +import copy +import json +import logging +import math +import os +import re +from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type, + Union) + +import safetensors.torch +import torch +from torch import nn + +from vllm.config import LoRAConfig +from vllm.utils import LRUCache + +from vllm.lora.layers import LoRALayer, LoRAMapping, from_layer, from_layer_sampler +from vllm.lora.lora import LoRA +from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule) + +logger = logging.getLogger(__name__) + +PACKED_MODULES_CFG = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], +} + +TARGET_MODULES_QKV = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", +] + +EMBEDDING_MODULES = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", +} + +EMBEDDING_PADDING_MODULES = ["lm_head"] + +_GLOBAL_LORA_ID = 0 + + +def convert_mapping( + mapping: LoRAMapping, lora_id_to_index: List[Optional[int]], + max_loras: int, vocab_size: int, extra_vocab_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """Converts 
LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_id_to_index: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + indices_len: List of lengths of the above tensors. + """ + indices = list(mapping.index_mapping).copy() + embedding_indices = indices.copy() + lora_indices = indices.copy() + prompt_mapping = [ + lora_id_to_index.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(indices)): + # TODO index can be slow. optimize + lora_idx = (lora_id_to_index.index(indices[i]) + if indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if indices[i] > 0 else 0 + indices[i] = i + lora_indices[i] = lora_idx + + indices = torch.tensor([indices, lora_indices, embedding_indices], + dtype=torch.long, + device="cuda") + prompt_mapping = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size) + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = ( + torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + + (sampler_indices_padded * len(sampler_indices_padded))) + indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1]) + + return (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, indices_len) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +def _create_dummy_lora(module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + embeddings_tensor_dim: Optional[int] = None) -> "LoRA": + lora_a = torch.zeros([input_dim, rank], dtype=dtype, device=device) + lora_b = torch.zeros([rank, output_dim], dtype=dtype, device=device) + embeddings_tensor = torch.rand( + 10, embeddings_tensor_dim, dtype=dtype, + device=device) if embeddings_tensor_dim else None + if str(device) == "cpu": + lora_a = lora_a.pin_memory() + lora_b = lora_b.pin_memory() + if embeddings_tensor is not None: + embeddings_tensor = embeddings_tensor.pin_memory() + return LoRA( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + embeddings_tensor=embeddings_tensor, + ) + + +class LoRAModel: + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: Dict[str, LoRA], + ) -> None: + self.id 
= lora_model_id + assert (lora_model_id > + 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRA] = loras + + def get_lora(self, module_name: str) -> Optional[LoRA]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + rank: int, + lora_alpha: int, + tensors: Dict[str, torch.Tensor], + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + loras: Dict[str, LoRA] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + EMBEDDING_MODULES[embeddings_module]].to( + device=device, dtype=dtype) + if device == "cpu": + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRA(module_name, rank, lora_alpha, None, + None, lora_embeddings_tensor) + if is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if device == "cpu": + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + if any(name in module_name + for name in EMBEDDING_PADDING_MODULES + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if device == "cpu": + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for _, lora in loras.items(): + lora.optimize() + return cls(lora_model_id, rank, loras) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint.""" + lora_config_path = os.path.join(lora_dir, "adapter_config.json") + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + if os.path.isfile(lora_tensor_path): + tensors = safetensors.torch.load_file(lora_tensor_path) + elif os.path.isfile(lora_bin_file_path): + tensors = torch.load(lora_bin_file_path) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path) + + with open(lora_config_path) as f: + config = json.load(f) + rank = config["r"] + lora_alpha = config["lora_alpha"] + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + 
if lora_model_id is None else lora_model_id, + rank=rank, + lora_alpha=lora_alpha, + tensors=tensors, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + ) + + +class LoRAModelManager: + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + lora_target_modules: the target modules patterns to be adapted. + Support both single module name and a list of module names. + packed_modules_mapping: the mapping for packed modules. vLLM + packs some modules into one module, e.g., qkv_proj + is packed of q_proj, k_proj, and v_proj. These modules + have a single layer in the original model, but they are split + into multiple layers in the adapted model. + """ + self.lora_config = lora_config + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.max_num_seqs + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_id_to_index: List[Optional[int]] = [None] * self._lora_slots + self.vocab_size = vocab_size + self.base_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices = torch.empty(self.max_num_seqs, + dtype=torch.long, + device="cuda") + self.sampler_indices_padded = torch.empty(self.max_num_seqs, + dtype=torch.long, + device="cuda") + self.embeddings_indices = torch.empty(2, + self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.offsets = [] + self.indices_len = [None] * 4 + + self.model: nn.Module = model + self.lora_target_modules: List[str] = ([ + lora_target_modules + ] if isinstance(lora_target_modules, str) else lora_target_modules) + self.lora_target_modules = copy.deepcopy(lora_target_modules) + self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, "LoRALayer"] = {} + self._registered_loras: Dict[int, LoRAModel] = {} + self._active_loras: Dict[int, None] = {} + self._last_mapping = None + self._create_lora_modules() + self.model.lora_manager = self + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def _lora_slots(self) -> int: + return self.max_num_seqs + + def __len__(self) -> int: + return len(self._registered_loras) + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id in self._active_loras: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_id_to_index) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_loras[lora_id] = None + lora_model = self._registered_loras[lora_id] + logger.debug( + f"Activating LoRA. 
int id: {lora_model.id}, slot index: {index}") + self.lora_id_to_index[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = lora_model.get_lora(module_name) + if module_lora: + module_lora.optimize() + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor) + else: + module.reset_lora(index) + return True + + def _deactivate_lora(self, lora_id: int): + try: + index = self.lora_id_to_index.index(lora_id) + self.lora_id_to_index[index] = None + except ValueError: + pass + + def deactivate_lora(self, lora_id: int) -> bool: + if lora_id in self._active_loras: + self._deactivate_lora(lora_id) + self._active_loras.pop(lora_id) + return True + return False + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + if lora.id not in self._registered_loras: + if len(self._registered_loras) >= self.capacity: + raise RuntimeError("No free LoRA slots.") + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + return True + return False + + def remove_lora(self, lora_id: int) -> bool: + """Remove a LoRAModel from the manager.""" + # TODO: should we check active lora? + self.deactivate_lora(lora_id) + return bool(self._registered_loras.pop(lora_id, None)) + + # TODO see if this can be vectorized + def convert_mapping(self, mapping: LoRAMapping) -> None: + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, + indices_len) = convert_mapping(mapping, self.lora_id_to_index, + self._lora_slots + 1, self.vocab_size, + self.lora_config.lora_extra_vocab_size) + self.base_indices[:base_indices.shape[0]].copy_(base_indices) + self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self.embeddings_indices[:embeddings_indices. 
+ shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + # Maintain the reference + self.indices_len[:] = indices_len + + def set_row_lora_mapping(self, lora_mapping: LoRAMapping) -> None: + if self._last_mapping != lora_mapping: + self.convert_mapping(lora_mapping) + self._last_mapping = lora_mapping + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras) + + def get_lora(self, lora_id: int) -> Optional[LoRAModel]: + return self._registered_loras.get(lora_id, None) + + def remove_all_loras(self) -> bool: + """Remove all LoRAModels from the manager.""" + self._registered_loras.clear() + self.lora_id_to_index = [None] * self._lora_slots + self._active_loras.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name): + continue + + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.capacity, self.lora_config, + self.model.config)) + # (yard1): TODO make this more robust + if "lm_head" in module_name: + sampler_module = self.model.get_submodule("sampler") + new_module = replace_submodule( + self.model, "sampler", + from_layer_sampler(sampler_module, module, self.capacity, + self.lora_config, self.model.config)) + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + new_module.set_mapping(self.base_indices, self.sampler_indices, + self.sampler_indices_padded, + self.embeddings_indices, self.indices_len) + + def register_module(self, module_name: str, module: "LoRALayer"): + assert isinstance(module, LoRALayer) + self.modules[module_name] = module + + def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}) + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name) or not isinstance( + module, LoRALayer): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + if parts[-1] in EMBEDDING_MODULES: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = _create_dummy_lora( + module_name, + input_dim, + output_dim, + rank, + module.base_layer.weight.dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim) + else: + lora = _create_dummy_lora( + module_name, + module.base_layer.weight.shape[1], + module.base_layer.weight.shape[0], + rank, + module.base_layer.weight.dtype, + "cpu", + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras = [] + for r in replacements: + lora = _create_dummy_lora( + module_name + "." 
+ r, + module.base_layer.weight.shape[1], + module.base_layer.weight.shape[0] // len(replacements), + rank, + module.base_layer.weight.dtype, + "cpu", + ) + lora.optimize() + subloras.append(lora) + lora = LoRA.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.lora_target_modules) + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name) + if not replacements: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras = [] + has_replacement = False + for r in new_module_names: + lora = lora_model.get_lora(r) + replacement_loras.append(lora) + if lora: + has_replacement = True + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + lora_model.loras[module_name] = LoRA.pack(replacement_loras) + + +class LoRALRUCache(LRUCache): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_lora_fn = deactivate_lora_fn + + def _on_remove(self, key: Hashable, value: Any): + logger.debug(f"Removing LoRA. int id: {key}") + self.deactivate_lora_fn(key) + return super()._on_remove(key, value) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config, lora_target_modules, + packed_modules_mapping) + self._registered_loras: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_lora) + self._active_loras: LoRALRUCache = LoRALRUCache( + self.max_num_seqs, self._deactivate_lora) + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras.cache) + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + was_added = False + if lora.id not in self._registered_loras: + was_added = True + logger.debug(f"Adding LoRA. 
Model id: {lora.id}, " + f"int id: {lora.id}") + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + else: + # We always touch to update the LRU cache order + self._registered_loras.touch(lora.id) + return was_added + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_loras and len( + self._active_loras) >= self.max_num_seqs: + self._active_loras.remove_oldest() + result = super().activate_lora(lora_id) + # We always touch to update the LRU cache order + self._active_loras.touch(lora_id) + return result + + def remove_oldest_lora(self) -> bool: + if len(self._registered_loras) > 0: + self._registered_loras.remove_oldest() + return True + return False + + +def create_lora_adapter( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config:LoRAConfig, + target_modules: Union[str, + List[str]] = TARGET_MODULES_QKV, + lora_manager_cls:Type[LoRAModelManager] = LoRAModelManager, **kwargs)\ + -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not getattr(model, "supports_lora", False): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + lora_target_modules=target_modules, + **kwargs) + return lora_manager diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py new file mode 100644 index 0000000000000..ac96931b2d071 --- /dev/null +++ b/vllm/lora/punica.py @@ -0,0 +1,173 @@ +# Based on code from https://github.com/punica-ai/punica + +from typing import Optional + +import torch + +import_exc = None + +try: + import vllm._punica_C as punica_kernels +except ImportError as e: + import_exc = e + +if import_exc is None: + + def bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + ): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight + matrices. + indicies: Shape: `[B]`. Indices of the weight matrices. + layer_idx: Layer index of the weight matrices. + scale: Scaling factor. + """ + punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + + def add_lora(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + *, + buffer: Optional[torch.Tensor] = None): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + buffer: Optional. Shape: `[B, R]`. Temporary buffer. 
+ """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical innacuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, + 1.0) + punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, + scale) + + def add_lora_slice(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + y_offset: int, + y_slice_size: int, + *, + buffer: Optional[torch.Tensor] = None): + """ + Same as `add_lora` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. + + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical inaccuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv_low_level( + buffer, + x, + wa_t_all, + indicies, + layer_idx, + 1.0, + x.size(1), + buffer.size(1), + 0, + ) + punica_kernels.dispatch_bgmv_low_level( + y, + buffer, + wb_t_all, + indicies, + layer_idx, + scale, + buffer.size(1), + y_slice_size, + y_offset, + ) + +else: + + def _raise_exc( + *args, # pylint: disable=unused-argument + **kwargs # pylint: disable=unused-argument + ): + if torch.cuda.get_device_capability() < (8, 0): + raise ImportError( + "LoRA kernels require compute capability>=8.0") from import_exc + else: + raise import_exc + + bgmv = _raise_exc + add_lora = _raise_exc + add_lora_slice = _raise_exc + +__all__ = [ + "bgmv", + "add_lora", + "add_lora_slice", +] diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000000000..3ae5be59b1b88 --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + + +@dataclass +class LoRARequest: + lora_id: str + lora_int_id: int + lora_local_path: str + + def __post_init__(self): + if self.lora_int_id < 1: + raise ValueError( + f"lora_int_id must be > 0, got {self.lora_int_id}") + + def __eq__(self, value: object) -> bool: + return isinstance(value, LoRARequest) and self.lora_id == value.lora_id + + def __hash__(self) -> int: + return self.lora_int_id diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000000000..f67a3812fb046 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,39 @@ +import logging +from typing import Tuple + +from torch import nn + +logger = logging.getLogger(__name__) + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = 
model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + """ + parts = name.split(".") + assert parts[0] == "base_model" + assert parts[1] == "model" + if parts[-1] == "weight": + assert parts[-2] == "lora_A" or parts[-2] == "lora_B" + return ".".join(parts[2:-2]), parts[-2] == "lora_A" + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + + raise ValueError(f"{name} is unsupported format") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000000000..be6f4cf0589bd --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,266 @@ +import logging +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, List, Optional, Set, Type, Union + +import torch + +from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_adapter) +from vllm.lora.request import LoRARequest +from vllm.lora.layers import LoRAMapping +from vllm.config import LoRAConfig + +logger = logging.getLogger(__name__) + + +class AbstractWorkerLoRAManager(ABC): + """Abstract class for managing LoRA models on the worker side.""" + + def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, + vocab_size: int, lora_config: LoRAConfig, + device: torch.device): + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.device = device + self.lora_config = lora_config + + @abstractproperty + def is_enabled(self) -> bool: + ... + + @abstractmethod + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + ... + + @abstractmethod + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + ... + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + ... + + @abstractmethod + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + ... + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + ... + + @abstractmethod + def remove_all_loras(self) -> bool: + ... + + @abstractmethod + def list_loras(self) -> Set[int]: + ... 
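
As an illustration of the checkpoint-name convention handled by
parse_fine_tuned_lora_name above, a short standalone sketch (the module
paths are hypothetical examples, not taken from a real checkpoint):

    from vllm.lora.utils import parse_fine_tuned_lora_name

    # A LoRA "A" matrix for a hypothetical attention projection.
    module, is_lora_a = parse_fine_tuned_lora_name(
        "base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight")
    assert module == "model.layers.0.self_attn.q_proj" and is_lora_a

    # Embedding adapters are keyed by lora_embedding_A / lora_embedding_B.
    module, is_lora_a = parse_fine_tuned_lora_name(
        "base_model.model.model.embed_tokens.lora_embedding_B")
    assert module == "model.embed_tokens" and not is_lora_a
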
+ + +class DisabledWorkerLoRAManager(AbstractWorkerLoRAManager): + """WorkerLoRAManager that does nothing.""" + + @property + def is_enabled(self) -> bool: + return False + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + return model + + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + return + + def add_lora(self, lora_request: LoRARequest) -> bool: + return False + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + return False + + def remove_lora(self, lora_id: int) -> bool: + return False + + def remove_all_loras(self) -> bool: + return + + def list_loras(self) -> Set[int]: + return set() + + +class WorkerLoRAManager(AbstractWorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_model_cls: Type[LoRAModel] = LoRAModel, + ): + self._lora_manager: Optional[LoRAModelManager] = None + self._lora_model_cls = lora_model_cls + super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, + lora_config, device) + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_adapter( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + target_modules=target_modules, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + lora_manager_cls=self._lora_manager_cls, + ) + self._lora_manager = lora_manager + return lora_manager.model + + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self._apply_loras(lora_requests) + self._lora_manager.set_row_lora_mapping(lora_mapping) + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_that_exist = self.list_loras() + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.max_num_seqs: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.max_num_seqs}).") + + new_loras = set(loras_map) + loras_to_add = new_loras - loras_that_exist + loras_to_remove = loras_that_exist - new_loras + + for lora_id in loras_to_remove: + self.remove_lora(lora_id) + + for lora_id in loras_to_add: + self.add_lora(loras_map[lora_id]) + + def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + try: + lora = self._lora_model_cls.from_local_checkpoint( + lora_request.lora_local_path, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + ) + except Exception as e: + raise RuntimeError( + f"Loading lora {lora_request.lora_local_path} failed") from e + if lora.rank > self.lora_config.max_lora_rank: + raise ValueError( + f"LoRA rank {lora.rank} is greater than max_lora_rank " + f"{self.lora_config.max_lora_rank}.") + return lora + + def 
add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + return self._lora_manager.add_lora( + self._lora_manager.create_dummy_lora(lora_request.lora_int_id, + rank)) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + self._lora_manager.activate_lora(lora.id) + return loaded + + def remove_lora(self, lora_id: int) -> bool: + return self._lora_manager.remove_lora(lora_id) + + def remove_all_loras(self) -> bool: + self._lora_manager.remove_all_loras() + + def list_loras(self) -> Set[int]: + return set(self._lora_manager.list_loras()) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _lora_manager_cls: Type[ + LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_adapter( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_adapter( + model, + target_modules=target_modules, + lora_manager_cls=self._lora_manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._lora_manager = lora_manager + return lora_manager.model + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.max_num_seqs: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.max_num_seqs}).") + for lora in loras_map.values(): + self.add_lora(lora) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_loras(): + # Remove before we load the new lora to save memory + if len(self._lora_manager) + 1 > self._lora_manager.capacity: + self._lora_manager.remove_oldest_lora() + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + self._lora_manager.activate_lora(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index c874ec5921155..5bce287a92ae5 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -29,9 +29,24 @@ class Sampler(nn.Module): parameters (e.g., sampling method, temperature, top-p, top-k, etc.). """ - def __init__(self, vocab_size: int) -> None: + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None) -> None: super().__init__() self.vocab_size = vocab_size + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + + def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. 
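+        # `embedding` may carry extra rows beyond the real vocabulary
+        # (LoRA extra vocab and/or padding), so after the tensor-parallel
+        # all-gather the logits are sliced back to `self.org_vocab_size`.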
+ logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_all_gather(logits) + # Remove paddings in vocab (if any). + logits = logits[:, :self.org_vocab_size] + return logits def forward( self, @@ -44,8 +59,7 @@ def forward( hidden_states = _prune_hidden_states(hidden_states, input_metadata) # Get the logits for the next tokens. - logits = _get_logits(hidden_states, embedding, embedding_bias, - self.vocab_size) + logits = self._get_logits(hidden_states, embedding, embedding_bias) # Apply logits processors (if any). logits = _apply_logits_processors(logits, input_metadata) @@ -97,19 +111,6 @@ def forward( prompt_logprobs, sample_logprobs) -def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor], - vocab_size: int) -> torch.Tensor: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - logits = logits[:, :vocab_size] - return logits - - def _prune_hidden_states( hidden_states: torch.Tensor, input_metadata: InputMetadata, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b08d5555b0faa..9e4ac26e73d00 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -43,16 +43,19 @@ class VocabParallelEmbedding(torch.nn.Module): num_embeddings: vocabulary size. embedding_dim: size of hidden state. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). """ def __init__(self, num_embeddings: int, embedding_dim: int, - params_dtype: Optional[torch.dtype] = None): + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None): super().__init__() # Keep the input dimensions. self.num_embeddings = num_embeddings + self.org_vocab_size = org_num_embeddings or num_embeddings self.num_embeddings_padded = pad_vocab_size(num_embeddings) self.embedding_dim = embedding_dim if params_dtype is None: @@ -77,7 +80,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): parallel_dim = param.parallel_dim - assert loaded_weight.shape[parallel_dim] == self.num_embeddings + assert loaded_weight.shape[parallel_dim] == self.org_vocab_size loaded_weight = loaded_weight[self.vocab_start_index:self. vocab_end_index] param[:loaded_weight.shape[0]].data.copy_(loaded_weight) @@ -114,14 +117,17 @@ class ParallelLMHead(VocabParallelEmbedding): embedding_dim: size of hidden state. bias: whether to use bias. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). 
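+        When LoRA is enabled, num_embeddings includes the extra LoRA
+        vocabulary, while org_num_embeddings stays at the base checkpoint
+        size so weight loading can still validate the original shape.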
""" def __init__(self, num_embeddings: int, embedding_dim: int, bias: bool = False, - params_dtype: Optional[torch.dtype] = None): - super().__init__(num_embeddings, embedding_dim, params_dtype) + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 54b87c4b866e3..cf84b9810c575 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -1,12 +1,12 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Type +from typing import Optional, Type import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig +from vllm.config import ModelConfig, LoRAConfig from vllm.model_executor.models import * from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -58,7 +58,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: f"Supported architectures: {list(_MODEL_REGISTRY.keys())}") -def get_model(model_config: ModelConfig) -> nn.Module: +def get_model(model_config: ModelConfig, + lora_config: Optional[LoRAConfig] = None) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. @@ -87,7 +88,12 @@ def get_model(model_config: ModelConfig) -> nn.Module: with _set_default_torch_dtype(model_config.dtype): # Create a model instance. # The weights will be initialized as empty tensors. - model = model_class(model_config.hf_config, linear_method) + # TODO(yard1): Clean this up (lora_config) + try: + model = model_class(model_config.hf_config, linear_method, + lora_config) + except TypeError: + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": model = model.cuda() # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 8e7344da4888e..999c1097d0a42 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -43,6 +43,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -223,14 +224,19 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ LlamaDecoderLayer(config, linear_method) @@ -264,18 +270,25 @@ def forward( class LlamaForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config 
self.linear_method = linear_method - self.model = LlamaModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead(unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index d18572610741c..c67c3fae2028a 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -43,6 +43,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -217,15 +218,20 @@ def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ MistralDecoderLayer(config, linear_method) @@ -259,18 +265,27 @@ def forward( class MistralForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = MistralModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = MistralModel(config, + linear_method, + lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead(unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/outputs.py b/vllm/outputs.py index fe54926e06e64..534e9d5ea8a53 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -2,6 +2,7 @@ from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, SequenceStatus) +from vllm.lora.request import LoRARequest class CompletionOutput: @@ -16,6 +17,7 @@ class CompletionOutput: logprobs: The log probabilities of the top probability words at each position if the logprobs are requested. finish_reason: The reason why the sequence is finished. + lora_request: The LoRA request that was used to generate the output. 
""" def __init__( @@ -26,6 +28,7 @@ def __init__( cumulative_logprob: float, logprobs: Optional[SampleLogprobs], finish_reason: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: self.index = index self.text = text @@ -33,6 +36,7 @@ def __init__( self.cumulative_logprob = cumulative_logprob self.logprobs = logprobs self.finish_reason = finish_reason + self.lora_request = lora_request def finished(self) -> bool: return self.finish_reason is not None @@ -56,6 +60,7 @@ class RequestOutput: prompt_logprobs: The log probabilities to return per prompt token. outputs: The output sequences of the request. finished: Whether the whole request is finished. + lora_request: The LoRA request that was used to generate the output. """ def __init__( @@ -66,6 +71,7 @@ def __init__( prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -73,6 +79,7 @@ def __init__( self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished + self.lora_request = lora_request @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -108,8 +115,13 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": prompt_token_ids = seq_group.prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() - return cls(seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished) + return cls(seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + lora_request=seq_group.lora_request) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -117,4 +129,5 @@ def __repr__(self) -> str: f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " - f"finished={self.finished})") + f"finished={self.finished}, " + f"lora_request={self.lora_request})") diff --git a/vllm/sequence.py b/vllm/sequence.py index ecfaee6e8c3d6..06170ab79d69a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,6 +5,7 @@ from vllm.block import LogicalTokenBlock from vllm.sampling_params import SamplingParams +from vllm.lora.request import LoRARequest PromptLogprobs = List[Optional[Dict[int, float]]] SampleLogprobs = List[Dict[int, float]] @@ -105,6 +106,7 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + lora_request: LoRA request. """ def __init__( @@ -113,10 +115,12 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt self.block_size = block_size + self.lora_request = lora_request self.data = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -228,6 +232,7 @@ class SequenceGroup: seqs: The list of sequences. sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. + lora_request: LoRA request. 
""" def __init__( @@ -236,11 +241,13 @@ def __init__( seqs: List[Sequence], sampling_params: SamplingParams, arrival_time: float, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.lora_request = lora_request self.prompt_logprobs: Optional[PromptLogprobs] = None @property @@ -335,6 +342,7 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + lora_request: LoRA request. """ def __init__( @@ -344,12 +352,18 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.is_prompt = is_prompt self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.lora_request = lora_request + + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 class SequenceOutputs: diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 5b0481480a63b..b84f50c3bd5d7 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -4,6 +4,8 @@ PreTrainedTokenizerFast) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.utils import make_async, LRUCache logger = init_logger(__name__) @@ -69,6 +71,86 @@ def get_tokenizer( return tokenizer +def get_lora_tokenizer(lora_request: LoRARequest, *args, + **kwargs) -> Optional[PreTrainedTokenizer]: + if lora_request is None: + return None + try: + tokenizer = get_tokenizer(lora_request.lora_local_path, *args, + **kwargs) + except OSError as e: + # No tokenizer was found in the LoRA folder, + # use base model tokenizer + logger.warning( + f"No tokenizer found in {lora_request.lora_local_path}, " + "using base model tokenizer instead. 
" + f"(Exception: {str(e)})") + tokenizer = None + return tokenizer + + +get_lora_tokenizer_async = make_async(get_lora_tokenizer) + + +class MultiLoRATokenizer: + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def ping(self): + return True + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], diff --git a/vllm/utils.py b/vllm/utils.py index 47e51048fed45..9282db842c1d2 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -4,9 +4,20 @@ import psutil import torch +import asyncio +from functools import partial +from typing import ( + Awaitable, + Callable, + TypeVar, +) +from collections import OrderedDict +from typing import Any, Hashable, Optional from vllm._C import cuda_utils +T = TypeVar("T") + class Device(enum.Enum): GPU = enum.auto() @@ -27,6 +38,69 @@ def reset(self) -> None: self.counter = 0 +class LRUCache: + + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def __getitem__(self, key: Hashable) -> Any: + return self.get(key) + + def __setitem__(self, key: Hashable, value: Any) -> None: + self.put(key, value) + + def __delitem__(self, key: Hashable) -> None: + self.pop(key) + + def touch(self, key: Hashable) -> None: + self.cache.move_to_end(key) + + def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + else: + value = default_value + return value + + def put(self, key: Hashable, 
value: Any) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + self._remove_old_if_needed() + + def _on_remove(self, key: Hashable, value: Any): + pass + + def remove_oldest(self): + if not self.cache: + return + key, value = self.cache.popitem(last=False) + self._on_remove(key, value) + + def _remove_old_if_needed(self) -> None: + while len(self.cache) > self.capacity: + self.remove_oldest() + + def pop(self, key: int, default_value: Optional[Any] = None) -> Any: + run_on_remove = key in self.cache + value = self.cache.pop(key, default_value) + if run_on_remove: + self._on_remove(key, value) + return value + + def clear(self): + while len(self.cache) > 0: + self.remove_oldest() + self.cache.clear() + + def get_max_shared_memory_bytes(gpu: int = 0) -> int: """Returns the maximum shared memory per thread block in bytes.""" # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html @@ -53,3 +127,19 @@ def random_uuid() -> str: def in_wsl() -> bool: # Reference: https://github.com/microsoft/WSL/issues/4071 return "microsoft" in " ".join(uname()).lower() + + +def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: + """Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args, **kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=None, func=p_func) + + return _async_wrapper diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 702767ebd8d09..d316b9588bf75 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,12 +1,13 @@ """A GPU worker class.""" +import gc import os -from typing import Dict, List, Tuple, Optional +from typing import Dict, List, Tuple, Set, Optional import torch import torch.distributed from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.model_executor import get_model, InputMetadata, set_random_seed from vllm.model_executor.parallel_utils.parallel_state import ( initialize_model_parallel) @@ -14,6 +15,14 @@ from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.utils import get_gpu_memory +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import ( + DisabledWorkerLoRAManager, + LRUCacheWorkerLoRAManager, +) +from vllm.lora.layers import LoRAMapping + +LORA_WARMUP_RANK = 8 class Worker: @@ -31,12 +40,14 @@ def __init__( scheduler_config: SchedulerConfig, rank: Optional[int] = None, distributed_init_method: Optional[str] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.rank = rank self.distributed_init_method = distributed_init_method + self.lora_config = lora_config # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). @@ -46,6 +57,7 @@ def __init__( self.cache_engine = None self.cache_events = None self.gpu_cache = None + self.lora_manager = None def init_model(self): # This env var set by Ray causes exceptions with graph building. 
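
The LRUCache helper added to vllm/utils.py above can be exercised on its
own; a minimal sketch of the eviction behaviour (keys and values here are
arbitrary):

    from vllm.utils import LRUCache

    cache = LRUCache(capacity=2)
    cache.put("a", 1)
    cache.put("b", 2)
    cache.get("a")       # refreshes "a", so "b" becomes the oldest entry
    cache.put("c", 3)    # over capacity -> remove_oldest() evicts "b"
    assert "b" not in cache and cache.get("a") == 1
    # Missing keys return None (or the supplied default) instead of raising;
    # subclasses may override _on_remove() to run cleanup on eviction.
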
@@ -69,7 +81,21 @@ def init_model(self): set_random_seed(self.model_config.seed) def load_model(self): - self.model = get_model(self.model_config) + self.model = get_model(self.model_config, self.lora_config) + + vocab_size = self.model.config.vocab_size + + if self.lora_config: + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, vocab_size, + self.lora_config, self.device) + self.model = self.lora_manager.create_lora_adapter(self.model) + else: + self.lora_manager = DisabledWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens, vocab_size, + self.lora_config, self.device) @torch.inference_mode() def profile_num_available_blocks( @@ -91,6 +117,24 @@ def profile_num_available_blocks( sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs + + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests = [] + if self.lora_config: + for idx in range(max_num_seqs): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_id=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + seqs = [] for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + @@ -102,11 +146,21 @@ def profile_num_available_blocks( seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, + lora_request=dummy_lora_requests[group_id] + if dummy_lora_requests else None, ) seqs.append(seq) - input_tokens, input_positions, input_metadata = self._prepare_inputs( - seqs) + ( + input_tokens, + input_positions, + input_metadata, + lora_mapping, + prepared_lora_requests, + ) = self._prepare_inputs(seqs) + + if dummy_lora_requests: + self.apply_loras(prepared_lora_requests, lora_mapping) # Execute the model. num_layers = self.model_config.get_num_layers(self.parallel_config) @@ -131,6 +185,8 @@ def profile_num_available_blocks( num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + self.lora_manager.remove_all_loras() + gc.collect() torch.cuda.empty_cache() # Reset the seed to ensure that the random state is not affected by @@ -151,7 +207,8 @@ def init_cache_engine(self, cache_config: CacheConfig) -> None: def _prepare_inputs( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, LoRAMapping, + Set[LoRARequest]]: seq_groups: List[Tuple[List[int], SamplingParams]] = [] input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] @@ -160,6 +217,9 @@ def _prepare_inputs( selected_token_start_idx = 0 categorized_sample_indices = {t: [] for t in SamplingType} categorized_sample_indices_start_idx = 0 + lora_requests: Set[LoRARequest] = set() + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] # Add prompt tokens. 
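+        # lora_index_mapping records one adapter id per input token (later
+        # padded and flattened), while lora_prompt_mapping records one id per
+        # position that logits are sampled for; both are combined into the
+        # LoRAMapping built at the end of _prepare_inputs.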
prompt_lens: List[int] = [] @@ -170,6 +230,7 @@ def _prepare_inputs( seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) + lora_id = seq_group_metadata.lora_int_id # Use any sequence in the group. seq_id = seq_ids[0] @@ -187,6 +248,17 @@ def _prepare_inputs( categorized_sample_indices_start_idx) categorized_sample_indices_start_idx += 1 + if lora_id > 0: + # if we are preparing inputs for the warmup step, we want the + # lora computation to take up the maximum possible amount of + # memory that way we can get a tighter upper bound on the + # amount of memory we can use and therefore not oom. If + # for_warmup is true, we add the lora lora mapping that is used + # during generation. + lora_requests.add(seq_group_metadata.lora_request) + lora_index_mapping.append([lora_id] * prompt_len) + lora_prompt_mapping.append(lora_id) + input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt # is always the first token in the sequence. @@ -233,6 +305,7 @@ def _prepare_inputs( seq_ids = list(seq_group_metadata.seq_data.keys()) sampling_params = seq_group_metadata.sampling_params seq_groups.append((seq_ids, sampling_params)) + lora_id = seq_group_metadata.lora_int_id num_seqs = len(seq_ids) selected_token_indices.extend( @@ -255,6 +328,7 @@ def _prepare_inputs( if self.sliding_window is not None: context_len = min(context_len, self.sliding_window) input_positions.append([position]) + lora_index_mapping.append([lora_id]) block_table = seq_group_metadata.block_tables[seq_id] @@ -274,6 +348,11 @@ def _prepare_inputs( block_table = block_table[-sliding_window_blocks:] generation_block_tables.append(block_table) + # Update LoRA mapping. + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + lora_prompt_mapping.append(lora_id) + padded_input_tokens = [ _pad_to_max(tokens, max_seq_len, pad=0) for tokens in input_tokens ] @@ -281,6 +360,10 @@ def _prepare_inputs( _pad_to_max(positions, max_seq_len, pad=0) for positions in input_positions ] + padded_lora_input_mapping = [ + _pad_to_max(mapping, max_seq_len, pad=0) + for mapping in lora_index_mapping + ] padded_slot_mapping = [ _pad_to_max(mapping, max_seq_len, pad=-1) for mapping in slot_mapping @@ -318,6 +401,14 @@ def _prepare_inputs( for seq_group_metadata in seq_group_metadata_list: seq_data.update(seq_group_metadata.seq_data) + flat_padded_lora_input_mapping = [ + item for sublist in padded_lora_input_mapping for item in sublist + ] + lora_mapping = LoRAMapping( + flat_padded_lora_input_mapping, + lora_prompt_mapping, + ) + input_metadata = InputMetadata( seq_groups=seq_groups, seq_data=seq_data, @@ -330,7 +421,7 @@ def _prepare_inputs( categorized_sample_indices=categorized_sample_indices, sliding_window=self.sliding_window, ) - return tokens_tensor, positions_tensor, input_metadata + return tokens_tensor, positions_tensor, input_metadata, lora_mapping, lora_requests @torch.inference_mode() def execute_model( @@ -362,8 +453,20 @@ def execute_model( return {} # Prepare input tensors. 
- input_tokens, input_positions, input_metadata = self._prepare_inputs( - seq_group_metadata_list) + ( + input_tokens, + input_positions, + input_metadata, + lora_mapping, + lora_requests, + ) = self._prepare_inputs(seq_group_metadata_list) + + if self.lora_config: + lora_requests = [ + seq_group_metadata.lora_request + for seq_group_metadata in seq_group_metadata_list + ] + self.apply_loras(lora_requests, lora_mapping) # Execute the model. output = self.model( @@ -375,6 +478,19 @@ def execute_model( ) return output + def apply_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self.lora_manager.apply_loras(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.lora_manager.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.lora_manager.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.lora_manager.list_loras() + def _init_distributed_environment( parallel_config: ParallelConfig, From dd1726f61d1b356fafca130458ae869fb05665e4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 27 Nov 2023 17:15:04 -0800 Subject: [PATCH 02/35] Lint --- vllm/model_executor/model_loader.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index bfe72bbf8c1e5..dc3299dc36f29 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -102,5 +102,5 @@ def get_model(model_config: ModelConfig, else: # Load the weights from the cached or downloaded files. model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + model_config.load_format, model_config.revision) return model.eval() From 6c66b6e7527a3f23a0af6b9777c6a71e52c52ac2 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 10:56:35 -0800 Subject: [PATCH 03/35] Add rank check --- csrc/punica/bgmv/bgmv_config.h | 6 ++++-- vllm/config.py | 8 ++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 3fd56b685be13..2c77663c0c617 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -44,10 +44,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ +// Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 128) // clang-format on diff --git a/vllm/config.py b/vllm/config.py index eef6e53be2855..2b0a767ff08bf 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -359,6 +359,14 @@ class LoRAConfig: lora_extra_vocab_size: int = 256 max_loras: Optional[int] = None + def __post_init__(self): + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + possible_max_ranks = (8, 16, 32, 64, 128) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + def verify_with_model_config(self, model_config: ModelConfig): if self.lora_dtype in (None, "auto"): self.lora_dtype = model_config.dtype From 
70eaca69ca4b4069c74d58f79d457f8d733b3d08 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 11:54:44 -0800 Subject: [PATCH 04/35] Add example, minor tweaks --- examples/multilora_inference.py | 91 +++++++++++++++++++++++++++++ vllm/engine/arg_utils.py | 12 ++-- vllm/engine/llm_engine.py | 18 ++++-- vllm/lora/models.py | 4 +- vllm/lora/request.py | 12 ++++ vllm/model_executor/model_loader.py | 5 +- vllm/worker/worker.py | 4 +- 7 files changed, 128 insertions(+), 18 deletions(-) create mode 100644 examples/multilora_inference.py diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py new file mode 100644 index 0000000000000..65885e534b508 --- /dev/null +++ b/examples/multilora_inference.py @@ -0,0 +1,91 @@ +""" +This example shows how to use the multi-LoRA functionality for offline inference. + +Requires HuggingFace credentials for access to Llama2. +""" + +from typing import Optional, List, Tuple + +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput +from vllm.lora.request import LoRARequest + + +def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: + """Create a list of test prompts with their sampling parameters. + + 2 requests for base model, 2 requests for the LoRA. + + In this example, we only use one LoRA adapter. However, we could + specify multiple adapters and use them in the same way. + """ + return [ + ("A robot may not injure a human being", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), None), + ("To be or not to be,", + SamplingParams(temperature=0.8, + top_k=5, + presence_penalty=0.2, + max_tokens=128), None), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? 
[/user] [assistant]", + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + print(request_output) + + +def initialize_engine() -> LLMEngine: + """Initialize the LLMEngine.""" + engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_num_seqs=32) + return LLMEngine.from_engine_args(engine_args) + + +def main(): + """Main function that sets up and runs the prompt processing.""" + engine = initialize_engine() + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + process_requests(engine, test_prompts) + + +if __name__ == '__main__': + main() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4d1233c473980..6dc695aaa554c 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,9 +34,9 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enable_lora: bool = False - max_lora_rank: int = 8 + max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 - lora_dtype = 'bfloat16' + lora_dtype = 'auto' lora_max_cpu_loras: int = -1 def __post_init__(self): @@ -193,21 +193,21 @@ def add_cli_args( help='enable lora adapters') parser.add_argument('--max-lora-rank', type=int, - default=16, + default=EngineArgs.max_lora_rank, help='max LoRA rank') parser.add_argument('--lora-extra-vocab-size', type=int, - default=256, + default=EngineArgs.lora_extra_vocab_size, help='LoRA extra vocab size') parser.add_argument('--lora-dtype', type=str, - default=EngineArgs.dtype, + default=EngineArgs.lora_dtype, choices=['auto', 'float16', 'bfloat16', 'float32'], help='data type for lora') parser.add_argument( '--lora-max-cpu-loras', type=int, - default=-1, + default=EngineArgs.lora_max_cpu_loras, help=('Maximum number of loras to store in CPU memory. ' 'Must be >= than max_num_seqs. 
' 'Defaults to max_num_seqs.')) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index c6e74b1d26586..74e18561e6401 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -369,11 +369,13 @@ def _check_beam_search_early_stopping( current_worst_score = (current_worst_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.get_lora_tokenizer( + current_worst_seq.lora_request).eos_token_id)) if early_stopping is False: highest_attainable_score = (best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.get_lora_tokenizer( + best_running_seq.lora_request).eos_token_id)) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -387,7 +389,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.tokenizer.get_lora_tokenizer( + best_running_seq.lora_request).eos_token_id, seq_len=max_possible_length)) else: # Otherwise, beam search will prefer shorter sequences. The @@ -396,7 +399,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.get_lora_tokenizer( + best_running_seq.lora_request).eos_token_id)) return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -487,7 +491,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the finished sequences by their scores. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request + ).eos_token_id), reverse=True) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: @@ -515,7 +520,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the running sequences by their scores. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request + ).eos_token_id), reverse=True) # Check if we can stop the beam search. diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 913234475b182..60034bdbb6e6b 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -321,10 +321,10 @@ def __init__( self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") - self.sampler_indices = torch.empty(self.max_num_seqs, + self.sampler_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") - self.sampler_indices_padded = torch.empty(self.max_num_seqs, + self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, dtype=torch.long, device="cuda") self.embeddings_indices = torch.empty(2, diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 3ae5be59b1b88..5d45f8a0f396d 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -3,6 +3,18 @@ @dataclass class LoRARequest: + """ + Request for a LoRA adapter. + + Note that this class should be be used internally. 
For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_id and lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + lora_id: str lora_int_id: int lora_local_path: str diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index dc3299dc36f29..0cd890615c918 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -89,11 +89,10 @@ def get_model(model_config: ModelConfig, # Create a model instance. # The weights will be initialized as empty tensors. with torch.device("cuda"): - # TODO(yard1): Clean this up (lora_config) - try: + if getattr(model_class, "supports_lora", True): model = model_class(model_config.hf_config, linear_method, lora_config) - except TypeError: + else: model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d316b9588bf75..b18668007028f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -257,7 +257,9 @@ def _prepare_inputs( # during generation. lora_requests.add(seq_group_metadata.lora_request) lora_index_mapping.append([lora_id] * prompt_len) - lora_prompt_mapping.append(lora_id) + lora_prompt_mapping.extend( + [lora_id] * + (prompt_len if sampling_params.prompt_logprobs else 1)) input_tokens.append(prompt_tokens) # NOTE(woosuk): Here we assume that the first token in the prompt From a3f191ae554edbf7212c62321bf5ce8d5702e375 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 12:23:59 -0800 Subject: [PATCH 05/35] Fix dummy lora init for packed layers --- vllm/lora/models.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 60034bdbb6e6b..4c01748f52f95 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -503,16 +503,16 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: input_dim, output_dim, rank, - module.base_layer.weight.dtype, + module.lora_a_stacked.dtype, "cpu", embeddings_tensor_dim=embeddings_tensor_dim) else: lora = _create_dummy_lora( module_name, - module.base_layer.weight.shape[1], - module.base_layer.weight.shape[0], + module.lora_a_stacked.shape[-1], + module.lora_b_stacked.shape[-2], rank, - module.base_layer.weight.dtype, + module.lora_a_stacked.dtype, "cpu", ) lora.optimize() @@ -520,13 +520,13 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: parts = module_name.split(".") replacements = self.packed_modules_mapping[parts[-1]] subloras = [] - for r in replacements: + for i, r in enumerate(replacements): lora = _create_dummy_lora( module_name + "." 
+ r, - module.base_layer.weight.shape[1], - module.base_layer.weight.shape[0] // len(replacements), + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], rank, - module.base_layer.weight.dtype, + module.lora_a_stacked[i].dtype, "cpu", ) lora.optimize() From 240cee93b98655b8b6942d66063507c1ef18a967 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 13:09:26 -0800 Subject: [PATCH 06/35] Fix capacity --- vllm/lora/models.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 4c01748f52f95..3f3fe8b997677 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -458,14 +458,14 @@ def _create_lora_modules(self): new_module = replace_submodule( self.model, module_name, - from_layer(module, self.capacity, self.lora_config, + from_layer(module, self._lora_slots, self.lora_config, self.model.config)) # (yard1): TODO make this more robust if "lm_head" in module_name: sampler_module = self.model.get_submodule("sampler") new_module = replace_submodule( self.model, "sampler", - from_layer_sampler(sampler_module, module, self.capacity, + from_layer_sampler(sampler_module, module, self._lora_slots, self.lora_config, self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) From c4d57a531699c6e41926213f02a3008b8d6d5215 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 13:23:19 -0800 Subject: [PATCH 07/35] Lint --- vllm/lora/models.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 3f3fe8b997677..ecb29c94c9821 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -465,8 +465,9 @@ def _create_lora_modules(self): sampler_module = self.model.get_submodule("sampler") new_module = replace_submodule( self.model, "sampler", - from_layer_sampler(sampler_module, module, self._lora_slots, - self.lora_config, self.model.config)) + from_layer_sampler(sampler_module, module, + self._lora_slots, self.lora_config, + self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) new_module.set_mapping(self.base_indices, self.sampler_indices, From 471f25a6b51e87cec00568b17e294d07c9087603 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 13:39:07 -0800 Subject: [PATCH 08/35] Remove rank 128 for now --- csrc/punica/bgmv/bgmv_config.h | 3 +-- vllm/config.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 2c77663c0c617..da6e6a611ecaa 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -49,7 +49,6 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) \ - FOR_BGMV_WIDE(f, in_T, out_T, W_T, 128) + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) // clang-format on diff --git a/vllm/config.py b/vllm/config.py index 2b0a767ff08bf..0b03565a5031a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -361,7 +361,7 @@ class LoRAConfig: def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h - possible_max_ranks = (8, 16, 32, 64, 128) + possible_max_ranks = (8, 16, 32, 64) if self.max_lora_rank not in possible_max_ranks: raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " From 
ccbb4b7f88395fea4362d42688a46712160adb0b Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 13:57:38 -0800 Subject: [PATCH 09/35] Pass to scheduler --- vllm/core/scheduler.py | 10 +++++++--- vllm/engine/llm_engine.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f8fb4c6ea1518..fce3f2acb65f5 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -2,7 +2,7 @@ import time from typing import Dict, Iterable, List, Optional, Tuple, Union, Set -from vllm.config import CacheConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory from vllm.lora.request import LoRARequest @@ -73,11 +73,11 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, - lora_enabled: bool = False, + lora_config: Optional[LoRAConfig], ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config - self.lora_enabled = lora_enabled + self.lora_config = lora_config self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -99,6 +99,10 @@ def __init__( # Sequence groups in the SWAPPED state. self.swapped: List[SequenceGroup] = [] + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 74e18561e6401..8ba04181ff47b 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -111,7 +111,7 @@ def __init__( self._init_cache() # Create the scheduler. - self.scheduler = Scheduler(scheduler_config, cache_config) + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) # Logging. 
self.last_logging_time = 0.0 From 5a1a0be6e06b1f155168023c28c8100fe6464143 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 15:52:13 -0800 Subject: [PATCH 10/35] Add simple scheduler support --- examples/multilora_inference.py | 20 ++++++++++-- vllm/config.py | 18 +++++------ vllm/core/scheduler.py | 55 +++++++++++++++++++++++++++------ vllm/engine/arg_utils.py | 6 ++++ vllm/lora/models.py | 23 +++++++------- vllm/lora/worker_manager.py | 8 ++--- vllm/sequence.py | 8 +++++ vllm/worker/worker.py | 13 +++++--- 8 files changed, 111 insertions(+), 40 deletions(-) diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 65885e534b508..9aa0edc35d32c 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -46,6 +46,21 @@ def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: max_tokens=128, stop_token_ids=[32003]), LoRARequest("sql-lora", 1, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora2", 2, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), ] @@ -68,14 +83,15 @@ def process_requests(engine: LLMEngine, for request_output in request_outputs: if request_output.finished: - print(request_output) + print(request_output.lora_request) def initialize_engine() -> LLMEngine: """Initialize the LLMEngine.""" engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True, - max_num_seqs=32) + max_loras=1, + max_num_seqs=256) return LLMEngine.from_engine_args(engine_args) diff --git a/vllm/config.py b/vllm/config.py index 0b03565a5031a..007f1026affa0 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -354,10 +354,10 @@ def _verify_args(self) -> None: @dataclass class LoRAConfig: max_lora_rank: int + max_loras: int max_cpu_loras: Optional[int] = None lora_dtype: Optional[torch.dtype] = None lora_extra_vocab_size: int = 256 - max_loras: Optional[int] = None def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h @@ -366,6 +366,14 @@ def __post_init__(self): raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " f"{possible_max_ranks}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_num_seqs ({self.max_loras})") def verify_with_model_config(self, model_config: ModelConfig): if self.lora_dtype in (None, "auto"): @@ -380,14 +388,6 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): "max_num_batched_tokens must be <= 65528 when " "LoRA is enabled.") - self.max_loras = scheduler_config.max_num_seqs - if self.max_cpu_loras is None: - self.max_cpu_loras = scheduler_config.max_num_seqs - 
elif self.max_cpu_loras < scheduler_config.max_num_seqs: - raise ValueError( - f"max_cpu_loras ({self.max_cpu_loras}) must be >= " - f"max_num_seqs ({scheduler_config.max_num_seqs})") - _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index fce3f2acb65f5..e9081f6b7d726 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -77,6 +77,9 @@ def __init__( ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. self.lora_config = lora_config self.prompt_limit = min(self.scheduler_config.max_model_len, @@ -151,14 +154,16 @@ def _schedule(self) -> SchedulerOutputs: # requests in the generation phase. num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. - while self.waiting: - seq_group = self.waiting[0] - + waiting_indices_to_remove = [] + for i, seq_group in enumerate(self.waiting): assert seq_group.num_seqs() == 1, ( "Waiting sequence group should have only one prompt " "sequence.") @@ -170,7 +175,7 @@ def _schedule(self) -> SchedulerOutputs: for seq in seq_group.get_seqs(): seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) - self.waiting.pop(0) + waiting_indices_to_remove.append(i) continue # If the sequence group cannot be allocated, stop. @@ -184,9 +189,18 @@ def _schedule(self) -> SchedulerOutputs: for seq in seq_group.get_seqs(): seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) - self.waiting.pop(0) + waiting_indices_to_remove.append(i) continue + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + continue + # If the number of batched tokens exceeds the limit, stop. 
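The configuration checks added in this series amount to a small amount of __post_init__ validation: the rank must be one of the compiled kernel sizes, max_loras must be positive, and max_cpu_loras defaults to (and must not be smaller than) max_loras. A minimal, self-contained sketch under the assumption of a stripped-down config; field and class names are illustrative.

# Illustrative sketch; MiniLoRAConfig is hypothetical, the checks mirror the patch.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniLoRAConfig:
    max_lora_rank: int
    max_loras: int
    max_cpu_loras: Optional[int] = None

    def __post_init__(self) -> None:
        # Keep in sync with the ranks compiled in bgmv_config.h.
        possible_max_ranks = (8, 16, 32, 64)
        if self.max_lora_rank not in possible_max_ranks:
            raise ValueError(
                f"max_lora_rank ({self.max_lora_rank}) must be one of "
                f"{possible_max_ranks}.")
        if self.max_loras < 1:
            raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.")
        if self.max_cpu_loras is None:
            # Default the CPU cache size to the number of GPU slots.
            self.max_cpu_loras = self.max_loras
        elif self.max_cpu_loras < self.max_loras:
            raise ValueError(
                f"max_cpu_loras ({self.max_cpu_loras}) must be >= "
                f"max_loras ({self.max_loras}).")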
new_seq_lens = seq_lens + [num_prompt_tokens] num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) @@ -206,12 +220,17 @@ def _schedule(self) -> SchedulerOutputs: break seq_lens = new_seq_lens - seq_group = self.waiting.pop(0) + waiting_indices_to_remove.append(i) + if lora_int_id > 0: + curr_loras.add(lora_int_id) self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append(seq_group) + for i in reversed(waiting_indices_to_remove): + self.waiting.pop(i) + if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, @@ -260,9 +279,22 @@ def _schedule(self) -> SchedulerOutputs: if not preempted: num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + swapped_indices_to_remove = [] + + for i, seq_group in enumerate(self.swapped): + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + continue - while self.swapped: - seq_group = self.swapped[0] # If the sequence group cannot be swapped in, stop. if not self.block_manager.can_swap_in(seq_group): break @@ -274,12 +306,17 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - seq_group = self.swapped.pop(0) + swapped_indices_to_remove.append(i) + if lora_int_id > 0: + curr_loras.add(lora_int_id) self._swap_in(seq_group, blocks_to_swap_in) self._append_slot(seq_group, blocks_to_copy) num_curr_seqs += num_new_seqs self.running.append(seq_group) + for i in reversed(swapped_indices_to_remove): + self.swapped.pop(i) + # Each sequence in the generation phase only takes one token slot. # Therefore, the number of batched tokens is equal to the number of # sequences in the RUNNING state. 
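The admission policy added to the scheduler can be summarized as: a waiting request that needs a LoRA which is not yet active is skipped once the batch already uses max_loras distinct adapters. A minimal sketch of that rule; Group and schedule_waiting are hypothetical names, and the real scheduler additionally checks block, token, and sequence budgets, which are omitted here.

# Illustrative sketch; names are hypothetical, budgets other than LoRA slots omitted.
from typing import List, Optional, Set


class Group:

    def __init__(self, name: str, lora_int_id: int = 0) -> None:
        self.name = name
        self.lora_int_id = lora_int_id  # 0 means "no LoRA"


def schedule_waiting(waiting: List[Group], running: List[Group],
                     max_loras: Optional[int]) -> List[Group]:
    curr_loras: Set[int] = {g.lora_int_id for g in running if g.lora_int_id}
    scheduled: List[Group] = []
    to_remove: List[int] = []
    for i, group in enumerate(waiting):
        lora_id = group.lora_int_id
        if (max_loras is not None and lora_id > 0
                and lora_id not in curr_loras
                and len(curr_loras) >= max_loras):
            # No free LoRA slot for this adapter; leave it in the queue.
            continue
        if lora_id > 0:
            curr_loras.add(lora_id)
        to_remove.append(i)
        scheduled.append(group)
    for i in reversed(to_remove):
        waiting.pop(i)
    return scheduled


# With max_loras=1, only requests sharing the already-active adapter get in.
waiting = [Group("a", 1), Group("b", 2), Group("c", 1)]
print([g.name for g in schedule_waiting(waiting, [Group("r", 1)], 1)])  # ['a', 'c']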
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6dc695aaa554c..a8c9d87215737 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -34,6 +34,7 @@ class EngineArgs: tokenizer_revision: Optional[str] = None quantization: Optional[str] = None enable_lora: bool = False + max_loras: int = 1 max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 lora_dtype = 'auto' @@ -191,6 +192,10 @@ def add_cli_args( parser.add_argument('--enable-lora', action='store_true', help='enable lora adapters') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='max number of LoRAs in a single batch') parser.add_argument('--max-lora-rank', type=int, default=EngineArgs.max_lora_rank, @@ -244,6 +249,7 @@ def create_engine_configs( self.max_paddings) lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, max_cpu_loras=self.lora_max_cpu_loras if self.lora_max_cpu_loras > diff --git a/vllm/lora/models.py b/vllm/lora/models.py index ecb29c94c9821..bdf100f52b5c3 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -314,9 +314,9 @@ def __init__( """ self.lora_config = lora_config self.max_num_seqs = max_num_seqs - assert self.capacity >= self.max_num_seqs + assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 - self.lora_id_to_index: List[Optional[int]] = [None] * self._lora_slots + self.lora_id_to_index: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, @@ -353,8 +353,8 @@ def capacity(self) -> int: return self.lora_config.max_cpu_loras @property - def _lora_slots(self) -> int: - return self.max_num_seqs + def lora_slots(self) -> int: + return self.lora_config.max_loras def __len__(self) -> int: return len(self._registered_loras) @@ -421,7 +421,7 @@ def convert_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, indices_len) = convert_mapping(mapping, self.lora_id_to_index, - self._lora_slots + 1, self.vocab_size, + self.lora_slots + 1, self.vocab_size, self.lora_config.lora_extra_vocab_size) self.base_indices[:base_indices.shape[0]].copy_(base_indices) self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) @@ -448,7 +448,7 @@ def get_lora(self, lora_id: int) -> Optional[LoRAModel]: def remove_all_loras(self) -> bool: """Remove all LoRAModels from the manager.""" self._registered_loras.clear() - self.lora_id_to_index = [None] * self._lora_slots + self.lora_id_to_index = [None] * self.lora_slots self._active_loras.clear() def _create_lora_modules(self): @@ -458,16 +458,15 @@ def _create_lora_modules(self): new_module = replace_submodule( self.model, module_name, - from_layer(module, self._lora_slots, self.lora_config, + from_layer(module, self.lora_slots, self.lora_config, self.model.config)) # (yard1): TODO make this more robust if "lm_head" in module_name: sampler_module = self.model.get_submodule("sampler") new_module = replace_submodule( self.model, "sampler", - from_layer_sampler(sampler_module, module, - self._lora_slots, self.lora_config, - self.model.config)) + from_layer_sampler(sampler_module, module, self.lora_slots, + self.lora_config, self.model.config)) self.register_module(module_name, new_module) self._register_packed_modules(module_name) 
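After this change the LoRA model manager distinguishes two sizes: lora_slots (adapters that can be live on the GPU in a single batch, driven by max_loras) and capacity (adapters kept resident in the cache, driven by max_cpu_loras). A minimal sketch of that split, with hypothetical class and field names.

# Illustrative sketch; MiniLoRAManagerSizes is hypothetical.
class MiniLoRAManagerSizes:

    def __init__(self, max_loras: int, max_cpu_loras: int) -> None:
        assert max_cpu_loras >= max_loras
        self._max_loras = max_loras
        self._max_cpu_loras = max_cpu_loras

    @property
    def lora_slots(self) -> int:
        # How many adapters can be activated on the GPU at once.
        return self._max_loras

    @property
    def capacity(self) -> int:
        # How many adapters can be kept cached (e.g. pinned on CPU).
        return self._max_cpu_loras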
new_module.set_mapping(self.base_indices, self.sampler_indices, @@ -604,7 +603,7 @@ def __init__( self._registered_loras: LoRALRUCache = LoRALRUCache( self.capacity, self.deactivate_lora) self._active_loras: LoRALRUCache = LoRALRUCache( - self.max_num_seqs, self._deactivate_lora) + self.lora_slots, self._deactivate_lora) def list_loras(self) -> Dict[int, LoRAModel]: """List all registered LoRAModels.""" @@ -629,7 +628,7 @@ def activate_lora( lora_id: int, ) -> bool: if lora_id not in self._active_loras and len( - self._active_loras) >= self.max_num_seqs: + self._active_loras) >= self.lora_slots: self._active_loras.remove_oldest() result = super().activate_lora(lora_id) # We always touch to update the LRU cache order diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index be6f4cf0589bd..4b90c6a556285 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -137,7 +137,7 @@ def create_lora_adapter( lora_config=self.lora_config, lora_manager_cls=self._lora_manager_cls, ) - self._lora_manager = lora_manager + self._lora_manager: LoRAModelManager = lora_manager return lora_manager.model def apply_loras(self, lora_requests: List[LoRARequest], @@ -155,7 +155,7 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._lora_manager.max_num_seqs}).") + f"({self._lora_manager.lora_slots}).") new_loras = set(loras_map) loras_to_add = new_loras - loras_that_exist @@ -235,7 +235,7 @@ def create_lora_adapter( lora_config=self.lora_config, max_num_batched_tokens=self.max_num_batched_tokens, ) - self._lora_manager = lora_manager + self._lora_manager: LRUCacheLoRAModelManager = lora_manager return lora_manager.model def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: @@ -247,7 +247,7 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " - f"({self._lora_manager.max_num_seqs}).") + f"({self._lora_manager.lora_slots}).") for lora in loras_map.values(): self.add_lora(lora) diff --git a/vllm/sequence.py b/vllm/sequence.py index 06170ab79d69a..036a697ab3491 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -137,6 +137,10 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def _append_logical_block(self) -> None: block = LogicalTokenBlock( block_number=len(self.logical_token_blocks), @@ -262,6 +266,10 @@ def prompt_token_ids(self) -> List[int]: # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).data.prompt_token_ids + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index b18668007028f..4cf04babd299f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -123,8 +123,9 @@ def profile_num_available_blocks( # consumption create dummy lora request copies from the lora request # passed in, which contains a lora from the lora warmup path. 
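The LRU behaviour used by the cache-based manager above works as follows: activating a new adapter when all lora_slots are taken evicts the least recently used active adapter first, and re-activating an adapter just refreshes its position. A minimal sketch with collections.OrderedDict standing in for the LRUCache helper used in the patch; the real manager also swaps the actual weights in and out.

# Illustrative sketch; OrderedDict stands in for the patch's LRUCache helper.
from collections import OrderedDict


class ActiveLoRACache:

    def __init__(self, lora_slots: int) -> None:
        self.lora_slots = lora_slots
        self._active: "OrderedDict[int, None]" = OrderedDict()

    def activate(self, lora_id: int) -> None:
        if lora_id in self._active:
            self._active.move_to_end(lora_id)  # touch: mark most recently used
            return
        if len(self._active) >= self.lora_slots:
            evicted, _ = self._active.popitem(last=False)  # oldest slot out
            print(f"evicting LoRA {evicted}")
        self._active[lora_id] = None


cache = ActiveLoRACache(lora_slots=2)
for lora_id in (1, 2, 1, 3):
    cache.activate(lora_id)
print(list(cache._active))  # [1, 3]: LoRA 2 was least recently used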
dummy_lora_requests = [] + dummy_lora_requests_per_seq = [] if self.lora_config: - for idx in range(max_num_seqs): + for idx in range(self.lora_config.max_loras): lora_id = idx + 1 dummy_lora_request = LoRARequest( lora_id=f"warmup_{lora_id}", @@ -134,6 +135,10 @@ def profile_num_available_blocks( self.lora_manager.add_dummy_lora(dummy_lora_request, rank=LORA_WARMUP_RANK) dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] seqs = [] for group_id in range(max_num_seqs): @@ -146,8 +151,8 @@ def profile_num_available_blocks( seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, - lora_request=dummy_lora_requests[group_id] - if dummy_lora_requests else None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, ) seqs.append(seq) @@ -159,7 +164,7 @@ def profile_num_available_blocks( prepared_lora_requests, ) = self._prepare_inputs(seqs) - if dummy_lora_requests: + if self.lora_config: self.apply_loras(prepared_lora_requests, lora_mapping) # Execute the model. From 1b00e500f1a1500e25840f26ac5c814ce790d0cf Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 15:58:38 -0800 Subject: [PATCH 11/35] Update example --- examples/multilora_inference.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 9aa0edc35d32c..91b8675827b93 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -15,10 +15,11 @@ def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: """Create a list of test prompts with their sampling parameters. - 2 requests for base model, 2 requests for the LoRA. - - In this example, we only use one LoRA adapter. However, we could - specify multiple adapters and use them in the same way. + 2 requests for base model, 4 requests for the LoRA. We define 2 + different LoRA adapters (using the same model for demo purposes). + Since we also set `max_loras=1`, the expectation is that the requests + with the second LoRA adapter will be ran after all requests with the + first adapter have finished. """ return [ ("A robot may not injure a human being", @@ -88,9 +89,18 @@ def process_requests(engine: LLMEngine, def initialize_engine() -> LLMEngine: """Initialize the LLMEngine.""" + # max_loras: controls the number of LoRAs that can be used in the same + # batch. Larger numbers will cause higher memory usage, as each LoRA + # slot requires its own preallocated tensor. + # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger + # numbers will cause higher memory usage. If you know that all LoRAs will + # use the same rank, it is recommended to set this as low as possible. + # max_cpu_loras: controls the size of the CPU LoRA cache. 
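A small sketch of the profiling change above: only max_loras dummy adapters are created, then spread round-robin over the max_num_seqs profiled sequences, so the memory profile reflects the configured slot count rather than one adapter per sequence. The values below are purely illustrative.

# Illustrative sketch; values are made up for demonstration.
max_loras = 2
max_num_seqs = 5
dummy_lora_ids = [idx + 1 for idx in range(max_loras)]
dummy_lora_per_seq = [
    dummy_lora_ids[idx % len(dummy_lora_ids)] for idx in range(max_num_seqs)
]
print(dummy_lora_per_seq)  # [1, 2, 1, 2, 1]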
engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", enable_lora=True, max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, max_num_seqs=256) return LLMEngine.from_engine_args(engine_args) From 6bda3c369a867a2b6496e884a62a40519c42ca79 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 16:02:50 -0800 Subject: [PATCH 12/35] Fix --- vllm/lora/worker_manager.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 4b90c6a556285..43c016c32193b 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -151,7 +151,7 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request } - if len(loras_map) > self._lora_manager.max_num_seqs: + if len(loras_map) > self._lora_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " @@ -243,7 +243,7 @@ def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: lora_request.lora_int_id: lora_request for lora_request in lora_requests if lora_request } - if len(loras_map) > self._lora_manager.max_num_seqs: + if len(loras_map) > self._lora_manager.lora_slots: raise RuntimeError( f"Number of requested LoRAs ({len(loras_map)}) is greater " "than the number of GPU LoRA slots " From de029618448c2d68c3a1e3e4af2624f0d6710bc8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 29 Nov 2023 16:08:18 -0800 Subject: [PATCH 13/35] Update tests --- tests/lora/conftest.py | 3 ++- tests/lora/test_layers.py | 16 ++++++++++++---- tests/lora/test_llama.py | 4 ++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 263a2bc9d8156..31803b741bd2e 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -124,7 +124,8 @@ def llama_2_7b_engine_extra_embeddings() -> nn.Module: get_model_old = get_model def get_model_patched(model_config, lora_config=None): - return get_model_old(model_config, LoRAConfig(max_lora_rank=8)) + return get_model_old(model_config, + LoRAConfig(max_loras=4, max_lora_rank=8)) with patch("vllm.worker.worker.get_model", get_model_patched): engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index fa6a18e8d93d2..319b33652b61a 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -170,7 +170,9 @@ def create_random_inputs( @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) def test_embeddings(dist_init, num_loras) -> None: - lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + lora_config = LoRAConfig(max_loras=8, + max_lora_rank=8, + lora_dtype=torch.float16) max_loras = 8 def create_random_embedding_layer(): @@ -258,7 +260,9 @@ def create_random_embedding_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: - lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + lora_config = LoRAConfig(max_loras=8, + max_lora_rank=8, + lora_dtype=torch.float16) max_loras = 8 def create_random_embedding_layer(): @@ -495,7 +499,9 @@ def create_random_sampler_layer(): @pytest.mark.parametrize("orientation", ["row", "column"]) def test_linear_parallel(dist_init, num_loras, orientation) -> None: - lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + lora_config = LoRAConfig(max_loras=8, 
+ max_lora_rank=8, + lora_dtype=torch.float16) max_loras = 8 def create_random_linear_parallel_layer(): @@ -589,7 +595,9 @@ def create_random_linear_parallel_layer(): @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [2, 3]) def test_column_parallel_packed(dist_init, num_loras, repeats) -> None: - lora_config = LoRAConfig(max_lora_rank=8, lora_dtype=torch.float16) + lora_config = LoRAConfig(max_loras=8, + max_lora_rank=8, + lora_dtype=torch.float16) max_loras = 8 def create_column_parallel_packed_layer(): diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index 756fc55246092..4760c5cc1e950 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -43,6 +43,7 @@ def test_llama_lora(sql_lora_files, tp_size): llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16, + max_loras=4, tensor_parallel_size=tp_size, worker_use_ray=True) @@ -85,6 +86,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): llm_tp1 = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16, + max_loras=4, tensor_parallel_size=1, worker_use_ray=True) output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) @@ -95,6 +97,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): llm_tp2 = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16, + max_loras=4, tensor_parallel_size=2, worker_use_ray=True) output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) @@ -107,6 +110,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): llm_tp4 = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16, + max_loras=4, tensor_parallel_size=4, worker_use_ray=True) output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) From 849831e300417e4f2df96e3279727f5b6082e460 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 30 Nov 2023 11:47:27 -0800 Subject: [PATCH 14/35] Cleanup --- examples/multilora_inference.py | 2 +- vllm/engine/arg_utils.py | 14 +++++++------- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 91b8675827b93..8fdd243af69ff 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -84,7 +84,7 @@ def process_requests(engine: LLMEngine, for request_output in request_outputs: if request_output.finished: - print(request_output.lora_request) + print(request_output) def initialize_engine() -> LLMEngine: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index cfeb1600ac8c6..0dae1613690e5 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -38,7 +38,7 @@ class EngineArgs: max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 lora_dtype = 'auto' - lora_max_cpu_loras: int = -1 + max_cpu_loras: int = -1 def __post_init__(self): if self.tokenizer is None: @@ -208,12 +208,12 @@ def add_cli_args( type=str, default=EngineArgs.lora_dtype, choices=['auto', 'float16', 'bfloat16', 'float32'], - help='data type for lora') + help='data type for LoRA') parser.add_argument( - '--lora-max-cpu-loras', + '--max-cpu-loras', type=int, - default=EngineArgs.lora_max_cpu_loras, - help=('Maximum number of loras to store in CPU memory. ' + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' 'Must be >= than max_num_seqs. 
' 'Defaults to max_num_seqs.')) return parser @@ -253,8 +253,8 @@ def create_engine_configs( max_loras=self.max_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, - max_cpu_loras=self.lora_max_cpu_loras if self.lora_max_cpu_loras > - 0 else None) if self.enable_lora else None + max_cpu_loras=self.max_cpu_loras + if self.max_cpu_loras > 0 else None) if self.enable_lora else None return model_config, cache_config, parallel_config, scheduler_config, lora_config From 66540339898f4378f3572251099559042944c20d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Mon, 4 Dec 2023 11:18:42 -0800 Subject: [PATCH 15/35] Do not pin memory in WSL --- vllm/lora/models.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index e821ab14cbf62..653d3b924cc3a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -12,7 +12,7 @@ from torch import nn from vllm.config import LoRAConfig -from vllm.utils import LRUCache +from vllm.utils import LRUCache, in_wsl from vllm.lora.layers import LoRALayer, LoRAMapping, from_layer, from_layer_sampler from vllm.lora.lora import LoRA @@ -138,16 +138,21 @@ def _create_dummy_lora(module_name: str, dtype: torch.dtype, device: torch.device, embeddings_tensor_dim: Optional[int] = None) -> "LoRA": - lora_a = torch.zeros([input_dim, rank], dtype=dtype, device=device) - lora_b = torch.zeros([rank, output_dim], dtype=dtype, device=device) + pin_memory = str(device) == "cpu" and not in_wsl() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) embeddings_tensor = torch.rand( - 10, embeddings_tensor_dim, dtype=dtype, - device=device) if embeddings_tensor_dim else None - if str(device) == "cpu": - lora_a = lora_a.pin_memory() - lora_b = lora_b.pin_memory() - if embeddings_tensor is not None: - embeddings_tensor = embeddings_tensor.pin_memory() + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None return LoRA( module_name, rank=rank, @@ -191,6 +196,7 @@ def from_lora_tensors( target_embedding_padding: Optional[int] = None, ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and not in_wsl() loras: Dict[str, LoRA] = {} for tensor_name, tensor in tensors.items(): module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) @@ -204,7 +210,7 @@ def from_lora_tensors( lora_embeddings_tensor = embeddings[ EMBEDDING_MODULES[embeddings_module]].to( device=device, dtype=dtype) - if device == "cpu": + if pin_memory: lora_embeddings_tensor = ( lora_embeddings_tensor.pin_memory()) loras[module_name] = LoRA(module_name, rank, lora_alpha, None, @@ -212,7 +218,7 @@ def from_lora_tensors( if is_lora_a: loras[module_name].lora_a = tensor.to(device=device, dtype=dtype).t() - if device == "cpu": + if pin_memory: loras[module_name].lora_a = loras[ module_name].lora_a.pin_memory() else: @@ -226,7 +232,7 @@ def from_lora_tensors( addition = target_embedding_padding - lora_b.shape[1] loras[module_name].lora_b = torch.nn.functional.pad( lora_b, (0, addition)) - if device == "cpu": + if pin_memory: loras[module_name].lora_b = loras[ module_name].lora_b.pin_memory() From cf633a7afda862d418578bba349fff6899228e03 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 5 Dec 2023 11:48:11 -0800 Subject: [PATCH 16/35] 
Raise error on unsupported model --- vllm/model_executor/model_loader.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 0cd890615c918..85cfa9228c136 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -92,6 +92,12 @@ def get_model(model_config: ModelConfig, if getattr(model_class, "supports_lora", True): model = model_class(model_config.hf_config, linear_method, lora_config) + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") else: model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": From 65d154282c44bcecda8919c800afc40e8d9ab25e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 5 Dec 2023 12:20:41 -0800 Subject: [PATCH 17/35] Support more vocab sizes --- csrc/punica/bgmv/bgmv_config.h | 4 +++ tests/lora/test_punica.py | 27 ++------------ vllm/config.py | 9 ++++- vllm/lora/layers.py | 36 +++++++++++++------ vllm/lora/lora.py | 5 +++ vllm/lora/models.py | 5 +++ vllm/lora/worker_manager.py | 5 +++ .../layers/vocab_parallel_embedding.py | 18 +++++++--- vllm/model_executor/models/llama.py | 12 +++++-- vllm/model_executor/models/mistral.py | 12 +++++-- 10 files changed, 86 insertions(+), 47 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index da6e6a611ecaa..ce2a2112a3d91 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -41,8 +41,12 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 28672) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ + f(in_T, out_T, W_T, narrow, 32512) \ + f(in_T, out_T, W_T, narrow, 32768) \ + f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ +// Keep above in sync with vllm/lora/layers::LoRASampler // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 26a7d47933309..f603b06cdb565 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -43,30 +43,9 @@ def _lora_ref_impl( H1 = H2 = [ - 128, - 256, - 512, - 1024, - 1280, - 2048, - 2560, - 2752, - 3072, - 3456, - 3584, - 4096, - 5120, - 5504, - 6912, - 7168, - 8192, - 9216, - 10240, - 11008, - 13824, - 14336, - 32000, - 32256, + 128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120, + 5504, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000, 32256, + 32512, 32768, 33024 ] SEED = [0xabcdabcd987] diff --git a/vllm/config.py b/vllm/config.py index b0a69edc900b1..7f4add75f96a2 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,4 @@ -from typing import Optional, Union +from typing import Optional, Union, ClassVar from dataclasses import dataclass import os @@ -364,14 +364,21 @@ class LoRAConfig: max_cpu_loras: Optional[int] = None lora_dtype: Optional[torch.dtype] = None lora_extra_vocab_size: int = 256 + # This is a constant. 
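The pin-memory guard introduced in the "Do not pin memory in WSL" commit above can be sketched as: CUDA pinned host memory is problematic under WSL, so pinning is only requested for CPU tensors outside WSL. The in_wsl body below is an assumption about how such a check is typically written (vLLM keeps its own helper in vllm/utils.py); the buffer helper name is hypothetical.

# Illustrative sketch; in_wsl() here is an assumed implementation, not vLLM's.
import platform

import torch


def in_wsl() -> bool:
    # WSL kernels report "microsoft" in their release string.
    return "microsoft" in platform.uname().release.lower()


def make_lora_buffer(rows: int, cols: int, device: str) -> torch.Tensor:
    pin_memory = str(device) == "cpu" and not in_wsl()
    return torch.zeros(rows, cols, device=device, pin_memory=pin_memory)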
+ lora_vocab_padding_size: ClassVar[int] = 256 def __post_init__(self): # Keep this in sync with csrc/punica/bgmv/bgmv_config.h possible_max_ranks = (8, 16, 32, 64) + possible_lora_extra_vocab_size = (0, 256, 512) if self.max_lora_rank not in possible_max_ranks: raise ValueError( f"max_lora_rank ({self.max_lora_rank}) must be one of " f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") if self.max_loras < 1: raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") if self.max_cpu_loras is None: diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 6ba8b0585847d..f03ae78ee8751 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,4 +1,5 @@ # pylint: disable=unused-argument +import math from dataclasses import dataclass from typing import TYPE_CHECKING, List, Optional, Tuple @@ -283,12 +284,12 @@ def set_lora( if self.embeddings_slice is not None: # TODO(yard1): Optimize this copy, we don't need to copy # everything, just the modified part - self.embeddings_weights.copy_( - self.embeddings_tensors.view( - self.embeddings_tensors.shape[0] * - self.embeddings_tensors.shape[1], - self.embeddings_tensors.shape[2]) - [self.embeddings_slice[0]:self.embeddings_slice[1]]) + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2] + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) def set_mapping( self, @@ -856,6 +857,11 @@ def create_lora_weights( lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, ) -> None: + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + if 32000 < self.base_layer.vocab_size > 33024: + raise ValueError( + "When using LoRA, vocab size must be 32000 >= vocab_size <= 33024" + ) self.lora_a_stacked = torch.zeros( ( max_loras, @@ -870,7 +876,10 @@ def create_lora_weights( ( max_loras, 1, - self.base_layer.vocab_size, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -933,8 +942,6 @@ def _get_logits( if embedding_bias is not None: logits += embedding_bias logits = tensor_model_parallel_all_gather(logits) - # Remove paddings in vocab (if any). - logits = logits[:, :self.base_layer.vocab_size] lora_logits = torch.empty( self.embeddings_tensors.shape[0] + 1, @@ -948,8 +955,7 @@ def _get_logits( out=lora_logits[:-1]) lora_logits[-1] = float("-inf") lora_logits = lora_logits.mT - - logits[:, self.base_layer.org_vocab_size:] = (lora_logits.reshape( + lora_logits = (lora_logits.reshape( lora_logits.shape[0] * lora_logits.shape[1], lora_logits.shape[2], ).index_select(0, @@ -957,6 +963,10 @@ def _get_logits( nan=float("-inf"), posinf=float("inf"), neginf=float("-inf"))) + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + _apply_lora( hidden_states, self.lora_a_stacked, @@ -964,6 +974,10 @@ def _get_logits( self.indices[:self.indices_len[1]], logits, ) + + # Remove paddings in vocab (if any). 
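The padding rule applied above keeps the LoRA logits kernel on supported shapes: the (possibly LoRA-extended) vocabulary is rounded up to a multiple of lora_vocab_padding_size (256 here), and models built with LoRA enabled pad their LM head accordingly. A minimal sketch of that arithmetic in plain Python, with no vLLM imports.

# Illustrative sketch of the rounding rule; constants mirror the patch.
import math


def pad_vocab_size(vocab_size: int, pad_to: int) -> int:
    return math.ceil(vocab_size / pad_to) * pad_to


LORA_VOCAB_PADDING_SIZE = 256
for vocab in (32000, 32100, 33000):
    print(vocab, "->", pad_vocab_size(vocab, LORA_VOCAB_PADDING_SIZE))
# 32000 -> 32000, 32100 -> 32256, 33000 -> 33024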
+ logits = logits[:, :self.base_layer.vocab_size] + return logits def forward(self, *args, **kwargs): diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 042a98597ab26..11006c4e1a1c9 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -70,6 +70,11 @@ def output_dim(self) -> int: def is_packed(self) -> bool: return False + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + class PackedLoRA(LoRA): """LoRA used for packed layers (eg. qkv_proj).""" diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 653d3b924cc3a..042b899763f10 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -178,6 +178,11 @@ def __init__( self.rank = rank self.loras: Dict[str, LoRA] = loras + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + def get_lora(self, module_name: str) -> Optional[LoRA]: """Get LoRA for a given module by name""" return self.loras.get(module_name, None) diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 43c016c32193b..363b7770be178 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -184,6 +184,11 @@ def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: raise ValueError( f"LoRA rank {lora.rank} is greater than max_lora_rank " f"{self.lora_config.max_lora_rank}.") + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} is greater than " + f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}." + ) return lora def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index 9e4ac26e73d00..9c5fb890251ed 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -13,8 +13,11 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs +DEFAULT_VOCAB_PADDING_SIZE = 64 -def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int: + +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to @@ -44,19 +47,22 @@ class VocabParallelEmbedding(torch.nn.Module): embedding_dim: size of hidden state. params_dtype: type of the parameters. org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. """ def __init__(self, num_embeddings: int, embedding_dim: int, params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None): + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): super().__init__() # Keep the input dimensions. self.num_embeddings = num_embeddings self.org_vocab_size = org_num_embeddings or num_embeddings - self.num_embeddings_padded = pad_vocab_size(num_embeddings) + self.num_embeddings_padded = pad_vocab_size(num_embeddings, + padding_size) self.embedding_dim = embedding_dim if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -118,6 +124,7 @@ class ParallelLMHead(VocabParallelEmbedding): bias: whether to use bias. params_dtype: type of the parameters. org_num_embeddings: original vocabulary size (without LoRA). 
+ padding_size: padding size for the vocabulary. """ def __init__(self, @@ -125,9 +132,10 @@ def __init__(self, embedding_dim: int, bias: bool = False, params_dtype: Optional[torch.dtype] = None, - org_num_embeddings: Optional[int] = None): + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): super().__init__(num_embeddings, embedding_dim, params_dtype, - org_num_embeddings) + org_num_embeddings, padding_size) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 342f940dbb92c..240f3ad57f655 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -291,9 +291,15 @@ def __init__( unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead(unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size) + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + # We need bigger padding if using lora for kernel + # compatibility + padding_size=64 + if not lora_config else lora_config.lora_vocab_padding_size, + ) self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 3547a72dc8558..7e2ee4a721e67 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -289,9 +289,15 @@ def __init__( unpadded_vocab_size = config.vocab_size if lora_config: unpadded_vocab_size += lora_config.lora_extra_vocab_size - self.lm_head = ParallelLMHead(unpadded_vocab_size, - config.hidden_size, - org_num_embeddings=config.vocab_size) + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + # We need bigger padding if using lora for kernel + # compatibility + padding_size=64 + if not lora_config else lora_config.lora_vocab_padding_size, + ) self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( From 008e92d51044e978891dd92f842c11cc09771378 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 6 Dec 2023 14:43:05 -0800 Subject: [PATCH 18/35] Update vllm/transformers_utils/tokenizer.py --- vllm/transformers_utils/tokenizer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 983227acce41e..2cfa417679dda 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -118,10 +118,6 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, self.lora_tokenizers = LRUCache(capacity=max_num_seqs) else: self.lora_tokenizers = None - - def ping(self): - return True - def encode(self, prompt: str, request_id: Optional[str] = None, From c328e587e8094a49f4b7abb0a797b9072ca03348 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 6 Dec 2023 14:51:08 -0800 Subject: [PATCH 19/35] Update vllm/transformers_utils/tokenizer.py --- vllm/transformers_utils/tokenizer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2cfa417679dda..695cb893e430e 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -118,6 +118,7 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, 
self.lora_tokenizers = LRUCache(capacity=max_num_seqs) else: self.lora_tokenizers = None + def encode(self, prompt: str, request_id: Optional[str] = None, From 8566144a01d5f572860f347a6fde8a627d5e3bc4 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Dec 2023 10:59:37 -0800 Subject: [PATCH 20/35] Reuse code --- vllm/lora/models.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 042b899763f10..1bd10abaa778a 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -622,14 +622,8 @@ def list_loras(self) -> Dict[int, LoRAModel]: def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" - was_added = False - if lora.id not in self._registered_loras: - was_added = True - logger.debug(f"Adding LoRA. Model id: {lora.id}, " - f"int id: {lora.id}") - self._create_merged_loras_inplace(lora) - self._registered_loras[lora.id] = lora - else: + was_added = super().add_lora(lora) + if not was_added: # We always touch to update the LRU cache order self._registered_loras.touch(lora.id) return was_added From 2d72ae560c46332987ee052cdfd7d96ee8fe6db1 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 13 Dec 2023 11:13:19 -0800 Subject: [PATCH 21/35] Naming --- vllm/lora/models.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 1bd10abaa778a..2f7c6154c7a24 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -52,14 +52,14 @@ def convert_mapping( - mapping: LoRAMapping, lora_id_to_index: List[Optional[int]], + mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], max_loras: int, vocab_size: int, extra_vocab_size: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: """Converts LoRAMapping to index tensors. Args: mapping: LoRAMapping mapping rows in a batch to LoRA ids. - lora_id_to_index: List mapping LoRA ids to LoRA indices. + lora_index_to_id: List mapping LoRA ids to LoRA indices. max_loras: Maximum number of LoRAs. vocab_size: Model vocab size. extra_vocab_size: Extra vocab size each LoRA can have. @@ -86,13 +86,13 @@ def convert_mapping( embedding_indices = indices.copy() lora_indices = indices.copy() prompt_mapping = [ - lora_id_to_index.index(x) if x > 0 else -1 + lora_index_to_id.index(x) if x > 0 else -1 for x in mapping.prompt_mapping ] lora_idx = None for i in range(len(indices)): # TODO index can be slow. 
optimize - lora_idx = (lora_id_to_index.index(indices[i]) + lora_idx = (lora_index_to_id.index(indices[i]) if indices[i] > 0 else -1) embedding_indices[i] = lora_idx if indices[i] > 0 else 0 indices[i] = i @@ -327,7 +327,7 @@ def __init__( self.max_num_seqs = max_num_seqs assert self.capacity >= self.lora_slots self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 - self.lora_id_to_index: List[Optional[int]] = [None] * self.lora_slots + self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots self.vocab_size = vocab_size self.base_indices = torch.empty(self.max_num_batched_tokens, dtype=torch.long, @@ -377,7 +377,7 @@ def activate_lora( if lora_id in self._active_loras: return False first_free_slot = next( - ((i, lora_id) for i, lora_id in enumerate(self.lora_id_to_index) + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) if lora_id is None), None) if first_free_slot is None: raise ValueError("No free lora slots") @@ -386,7 +386,7 @@ def activate_lora( lora_model = self._registered_loras[lora_id] logger.debug( f"Activating LoRA. int id: {lora_model.id}, slot index: {index}") - self.lora_id_to_index[index] = lora_model.id + self.lora_index_to_id[index] = lora_model.id for module_name, module in self.modules.items(): module_lora = lora_model.get_lora(module_name) if module_lora: @@ -399,8 +399,8 @@ def activate_lora( def _deactivate_lora(self, lora_id: int): try: - index = self.lora_id_to_index.index(lora_id) - self.lora_id_to_index[index] = None + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None except ValueError: pass @@ -431,7 +431,7 @@ def remove_lora(self, lora_id: int) -> bool: def convert_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, - indices_len) = convert_mapping(mapping, self.lora_id_to_index, + indices_len) = convert_mapping(mapping, self.lora_index_to_id, self.lora_slots + 1, self.vocab_size, self.lora_config.lora_extra_vocab_size) self.base_indices[:base_indices.shape[0]].copy_(base_indices) @@ -459,7 +459,7 @@ def get_lora(self, lora_id: int) -> Optional[LoRAModel]: def remove_all_loras(self) -> bool: """Remove all LoRAModels from the manager.""" self._registered_loras.clear() - self.lora_id_to_index = [None] * self.lora_slots + self.lora_index_to_id = [None] * self.lora_slots self._active_loras.clear() def _create_lora_modules(self): From 6640a2e65d9a685be637707e1f58351c6012616b Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 21 Dec 2023 17:01:12 -0800 Subject: [PATCH 22/35] Apply feedback from code review --- csrc/punica/bgmv/bgmv_all.cu | 2 +- csrc/punica/bgmv/bgmv_config.h | 2 +- tests/lora/test_layers.py | 43 +++++----- tests/lora/test_lora.py | 14 ++-- tests/lora/test_lora_manager.py | 52 ++++++------ tests/lora/test_worker.py | 8 +- tests/lora/utils.py | 12 +-- vllm/core/scheduler.py | 7 +- vllm/engine/arg_utils.py | 31 ++++--- vllm/engine/llm_engine.py | 34 ++++---- vllm/lora/layers.py | 121 +++++++++------------------- vllm/lora/lora.py | 89 +++++++++++++------- vllm/lora/models.py | 99 +++++++++-------------- vllm/lora/worker_manager.py | 60 +++----------- vllm/model_executor/model_loader.py | 2 +- vllm/worker/model_runner.py | 20 ++--- vllm/worker/worker.py | 1 - 17 files changed, 261 insertions(+), 336 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu index bc86416701f13..2502a67e3c813 100644 --- a/csrc/punica/bgmv/bgmv_all.cu +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -18,4 
+18,4 @@ FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) -FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) \ No newline at end of file +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index ce2a2112a3d91..ced0397dab216 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -46,7 +46,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 33024) \ f(in_T, out_T, W_T, narrow, 36864) \ f(in_T, out_T, W_T, narrow, 49152) \ -// Keep above in sync with vllm/lora/layers::LoRASampler +// Keep above in sync with vllm/lora/layers::SamplerWithLoRA // Keep this in sync with vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 1020d2cd684f1..71c671132205a 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,16 +8,16 @@ import torch.nn.functional as F from vllm.lora.layers import ( - LoRAColumnParallelLinear, - LoRAMergedColumnParallelLinear2Slice, - LoRAQKVParallelLinear, - LoRAVocabParallelEmbedding, - LoRARowParallelLinear, - LoRASampler, + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + VocabParallelEmbeddingWithLoRA, + RowParallelLinearWithLoRA, + SamplerWithLoRA, LoRAMapping, - LoRALayer, + BaseLayerWithLoRA, ) -from vllm.lora.models import LoRA, convert_mapping +from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights from vllm.config import LoRAConfig from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -66,11 +66,11 @@ def get_random_id_to_index(num_loras: int, def populate_loras( id_to_index: List[Optional[int]], - layer: LoRALayer, + layer: BaseLayerWithLoRA, layer_weights: torch.Tensor, generate_embeddings_tensor: int = 0, repeats: int = 1, -) -> Tuple[Dict[int, LoRA], Dict[int, List[LoRA]]]: +) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: """This method populates the lora layers with lora weights. Args: @@ -89,12 +89,12 @@ def populate_loras( # Dictionary that maps the lora ID to the # corresponding lora weights. - lora_dict: Dict[int, LoRA] = dict() + lora_dict: Dict[int, LoRALayerWeights] = dict() # Dictionary that maps the lora ID to the # corresponding subloras. Only useful when # repeats > 1. 
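What "packed" LoRA weights mean for merged layers such as gate_up_proj or qkv_proj: the fused base layer produces one concatenated output, and each slice carries its own (lora_a, lora_b, scaling) applied independently to its portion of that output. A minimal sketch of the per-slice apply step; tensor and function names are hypothetical, and batching over multiple adapters is omitted.

# Illustrative sketch; names are hypothetical, single adapter only.
from typing import List, Tuple

import torch


def apply_lora_packed_nslice(
    x: torch.Tensor,                      # [num_tokens, hidden]
    output: torch.Tensor,                 # [num_tokens, sum(slice_sizes)]
    slices: List[Tuple[torch.Tensor, torch.Tensor, float]],  # (A, B, scaling)
    slice_sizes: List[int],
) -> torch.Tensor:
    offset = 0
    for (lora_a, lora_b, scaling), size in zip(slices, slice_sizes):
        # lora_a: [hidden, rank], lora_b: [rank, size]
        output[:, offset:offset + size] += (x @ lora_a @ lora_b) * scaling
        offset += size
    return output


hidden, rank, size = 16, 4, 8
x = torch.randn(2, hidden)
out = torch.zeros(2, 2 * size)
slices = [(torch.randn(hidden, rank), torch.randn(rank, size), 0.5)
          for _ in range(2)]
apply_lora_packed_nslice(x, out, slices, [size, size])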
- sublora_dict: Dict[int, List[LoRA]] = dict() + sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() for slot_idx, lora_id in enumerate(id_to_index): if lora_id is not None: @@ -111,7 +111,8 @@ def populate_loras( sublora.optimize() subloras.append(sublora) - lora = LoRA.pack(subloras) if repeats > 1 else subloras[0] + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] layer.set_lora( slot_idx, @@ -179,7 +180,7 @@ def create_random_embedding_layer(): embedding = VocabParallelEmbedding(512, 256) embedding.weight.data = torch.rand_like(embedding.weight.data) embedding.weight.data[512:, :] = 0 - lora_embedding = LoRAVocabParallelEmbedding(embedding) + lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) lora_embedding.create_lora_weights(max_loras, lora_config) return embedding, lora_embedding @@ -277,7 +278,7 @@ def create_random_embedding_layer(): expanded_embedding.weight.data[:512, :] = embedding_data # We need to deepcopy the embedding as it will be modifed # in place - lora_embedding = LoRAVocabParallelEmbedding( + lora_embedding = VocabParallelEmbeddingWithLoRA( deepcopy(expanded_embedding)) lora_embedding.create_lora_weights(max_loras, lora_config) @@ -400,8 +401,8 @@ def create_random_sampler_layer(): linear.weight.data = torch.rand_like(linear.weight.data) linear.weight.data[:, 32000:] = 0 sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) - lora_sampler = LoRASampler(sampler, 1024, linear.weight.dtype, - linear.weight.device) + lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype, + linear.weight.device) lora_sampler.create_lora_weights(max_loras, lora_config) return linear, sampler, lora_sampler @@ -510,11 +511,11 @@ def create_random_linear_parallel_layer(): if orientation == "row": linear = RowParallelLinear(4096, 4096, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = LoRARowParallelLinear(linear) + lora_linear = RowParallelLinearWithLoRA(linear) else: linear = ColumnParallelLinear(4096, 4096, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = LoRAColumnParallelLinear(linear) + lora_linear = ColumnParallelLinearWithLoRA(linear) lora_linear.create_lora_weights(max_loras, lora_config) return linear, lora_linear @@ -608,11 +609,11 @@ def create_column_parallel_packed_layer(): linear = MergedColumnParallelLinear(4096, [4096] * repeats, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = LoRAMergedColumnParallelLinear2Slice(linear) + lora_linear = MergedColumnParallelLinearWithLoRA(linear) else: linear = QKVParallelLinear(4096, 64, 32, bias=False) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = LoRAQKVParallelLinear(linear) + lora_linear = QKVParallelLinearWithLora(linear) @dataclass class FakeConfig: diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py index b86f7a480e749..1b972cc53f24d 100644 --- a/tests/lora/test_lora.py +++ b/tests/lora/test_lora.py @@ -1,7 +1,7 @@ import pytest import torch -from vllm.lora.layers import _apply_lora, _apply_lora_packed_2slice, _apply_lora_packed_3slice +from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice from .utils import DummyLoRAManager @@ -122,19 +122,19 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T output = torch.zeros(k, m, device="cuda", dtype=dtype) - _apply_lora_packed_2slice( + _apply_lora_packed_nslice( input, lora_a_stacks, 
lora_b_stacks, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, m // 2) + device="cuda"), output, (m // 2, )) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora_packed_2slice(input, lora_a_stacks, lora_b_stacks, + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, torch.full((len(input), ), -1, device="cuda"), - output, m // 2) + output, (m // 2, )) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() @@ -206,7 +206,7 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) - _apply_lora_packed_3slice( + _apply_lora_packed_nslice( input, lora_a_stacks, lora_b_stacks, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), @@ -216,7 +216,7 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: assert torch.allclose(expected, output, rtol=rtol, atol=atol) output[:] = 0 - _apply_lora_packed_3slice(input, lora_a_stacks, lora_b_stacks, + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, torch.full((len(input), ), -1, device="cuda"), output, (qkv[0], qkv[1])) assert torch.allclose(torch.zeros_like(output), output) diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index de7b245ad4e79..9c52058ff9a51 100644 --- a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -7,9 +7,10 @@ from torch import nn from vllm.config import LoRAConfig -from vllm.lora.layers import (LoRAColumnParallelLinear, LoRARowParallelLinear, - LoRAMergedColumnParallelLinear2Slice) -from vllm.lora.lora import LoRA, PackedLoRA +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA) +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager, LRUCacheLoRAModelManager, LoRAMapping) from vllm.lora.request import LoRARequest @@ -54,7 +55,7 @@ def create_lora(lora_id: int, model: nn.Module, loras = {} for name in sub_modules: w = model.get_submodule(name).weight - loras[name] = LoRA( + loras[name] = LoRALayerWeights( name, 8, 16, @@ -76,7 +77,7 @@ def create_packed_lora( for replaced_module_name in replaced_module_names: if replaced_module_name == empty_replaced_module_name: continue - loras[replaced_module_name] = LoRA( + loras[replaced_module_name] = LoRALayerWeights( replaced_module_name, 8, 16, @@ -99,12 +100,13 @@ def test_replace_submodules(dist_init, dummy_model): lora_target_modules=["dense1", "layer1.dense2"]) model = manager.model - assert isinstance(model.get_submodule("dense1"), LoRAColumnParallelLinear) + assert isinstance(model.get_submodule("dense1"), + ColumnParallelLinearWithLoRA) assert isinstance(model.get_submodule("layer1.dense1"), - LoRAColumnParallelLinear) + ColumnParallelLinearWithLoRA) assert isinstance(model.get_submodule("dense2"), RowParallelLinear) assert isinstance(model.get_submodule("layer1.dense2"), - LoRARowParallelLinear) + RowParallelLinearWithLoRA) def test_lora_model_manager(dist_init, dummy_model): @@ -289,10 +291,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, worker_lora_manager = LRUCacheWorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, torch.device("cuda")) - 
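For orientation, the arithmetic these LoRA tests check the fused kernels against is the standard low-rank update: y = x @ W^T + (x @ A @ B) * scaling, with scaling = lora_alpha / rank. A minimal reference written without any vLLM or punica imports; shapes and values are illustrative.

# Illustrative reference implementation of the LoRA update; no vLLM imports.
import torch

num_tokens, hidden, out_dim, rank, lora_alpha = 4, 32, 16, 8, 16
x = torch.randn(num_tokens, hidden)
weight = torch.randn(out_dim, hidden)         # base layer weight
lora_a = torch.randn(hidden, rank) * 0.02     # down-projection
lora_b = torch.randn(rank, out_dim) * 0.02    # up-projection
scaling = lora_alpha / rank

y = x @ weight.t() + (x @ lora_a @ lora_b) * scaling
assert y.shape == (num_tokens, out_dim)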
worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) @@ -300,7 +302,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) @@ -311,7 +313,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 3 assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) @@ -322,7 +324,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) @@ -333,7 +335,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) @@ -346,7 +348,7 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -362,10 +364,10 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, worker_lora_manager = WorkerLoRAManager( 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, torch.device("cuda")) - worker_lora_manager.create_lora_adapter(llama_2_7b_model_extra_embeddings) + worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) mapping = LoRAMapping([], []) - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files) ], mapping) @@ -373,7 +375,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("3", 3, sql_lora_files), LoRARequest("4", 4, sql_lora_files) @@ -383,7 +385,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 3 assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 4 - 
worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("2", 2, sql_lora_files), LoRARequest("5", 5, sql_lora_files) @@ -393,7 +395,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files), LoRARequest("1", 1, sql_lora_files) @@ -403,7 +405,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, assert worker_lora_manager._lora_manager.lora_id_to_index[1] is None assert worker_lora_manager._lora_manager.lora_id_to_index[2] is None - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("6", 6, sql_lora_files), LoRARequest("7", 7, sql_lora_files), LoRARequest("8", 8, sql_lora_files) @@ -415,7 +417,7 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, # Over capacity with pytest.raises(RuntimeError): - worker_lora_manager.apply_loras([ + worker_lora_manager.set_active_loras([ LoRARequest("10", 10, sql_lora_files), LoRARequest("11", 11, sql_lora_files), LoRARequest("12", 12, sql_lora_files), @@ -446,12 +448,12 @@ def test_packed_loras(dist_init, dummy_model_gate_up): model = manager.model assert isinstance(model.get_submodule("gate_up_proj"), - LoRAMergedColumnParallelLinear2Slice) + MergedColumnParallelLinearWithLoRA) assert manager.add_lora(model_lora) assert manager.add_lora(model_lora1) packed_lora = model_lora.get_lora("gate_up_proj") - assert packed_lora and isinstance(packed_lora, PackedLoRA) + assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) assert torch.allclose(packed_lora.lora_a[0], model_lora.get_lora("gate_proj").lora_a) @@ -463,7 +465,7 @@ def test_packed_loras(dist_init, dummy_model_gate_up): model_lora.get_lora("up_proj").lora_b) packed_lora1 = model_lora1.get_lora("gate_up_proj") - assert packed_lora1 and isinstance(packed_lora1, PackedLoRA) + assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) assert packed_lora1.lora_a[0] is None assert packed_lora1.lora_b[0] is None diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index abc8babd55e93..126d910f53ab3 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -30,7 +30,7 @@ def test_worker_apply_lora(sql_lora_files): worker.init_model() worker.load_model() - worker.model_runner.apply_loras([], LoRAMapping([], [])) + worker.model_runner.set_active_loras([], LoRAMapping([], [])) assert worker.list_loras() == set() n_loras = 32 @@ -38,7 +38,7 @@ def test_worker_apply_lora(sql_lora_files): LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) ] - worker.model_runner.apply_loras(lora_requests, LoRAMapping([], [])) + worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], [])) assert worker.list_loras() == { lora_request.lora_int_id for lora_request in lora_requests @@ -50,8 +50,8 @@ def test_worker_apply_lora(sql_lora_files): k=random.randint(1, n_loras)) random.shuffle(iter_lora_requests) iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] - worker.model_runner.apply_loras(iter_lora_requests, LoRAMapping([], - [])) + worker.model_runner.set_active_loras(iter_lora_requests, + LoRAMapping([], [])) assert worker.list_loras().issuperset( {lora_request.lora_int_id for lora_request in 
iter_lora_requests}) diff --git a/tests/lora/utils.py b/tests/lora/utils.py index 072a0d957758b..280e0f2043e68 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -2,7 +2,7 @@ import torch -from vllm.lora.lora import LoRA +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights class DummyLoRAManager: @@ -11,10 +11,10 @@ def __init__(self): super().__init__() self._loras = {} - def set_module_lora(self, module_name: str, lora: LoRA): + def set_module_lora(self, module_name: str, lora: LoRALayerWeights): self._loras[module_name] = lora - def get_module_lora(self, module_name: str) -> Optional[LoRA]: + def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]: return self._loras.get(module_name, None) def init_random_lora(self, @@ -22,7 +22,7 @@ def init_random_lora(self, weight: torch.Tensor, rank: int = 8, generate_embeddings_tensor: int = 0): - lora = LoRA( + lora = LoRALayerWeights( module_name, rank=rank, lora_alpha=1, @@ -49,7 +49,7 @@ def init_lora(self, rank=8, noop=False, embeddings_tensor=None): - lora = LoRA( + lora = LoRALayerWeights( module_name, rank=rank, lora_alpha=1, @@ -83,6 +83,6 @@ def init_packed_lora( noop=i in noop_lora_index, ) base_loras.append(base_lora) - packed_lora = LoRA.pack(base_loras) + packed_lora = PackedLoRALayerWeights.pack(base_loras) self.set_module_lora(module_name, packed_lora) return packed_lora diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 14647958cbd29..fc5ee185c4045 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -37,7 +37,6 @@ def __init__( blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], ignored_seq_groups: List[SequenceGroup], - lora_enabled: bool = False, ) -> None: self.scheduled_seq_groups = scheduled_seq_groups self.prompt_run = prompt_run @@ -49,8 +48,8 @@ def __init__( assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups - if lora_enabled: - self.num_loras = len(set(self.lora_requests)) + self.num_loras = len(self.lora_requests) + if self.num_loras > 0: self._sort_by_lora_ids() def is_empty(self) -> bool: @@ -243,7 +242,6 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=ignored_seq_groups, - lora_enabled=self.lora_enabled, ) return scheduler_outputs @@ -334,7 +332,6 @@ def _schedule(self) -> SchedulerOutputs: blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, ignored_seq_groups=[], - lora_enabled=self.lora_enabled, ) return scheduler_outputs diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 23f86e7637037..62e5aa5257914 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -40,7 +40,7 @@ class EngineArgs: max_lora_rank: int = 16 lora_extra_vocab_size: int = 256 lora_dtype = 'auto' - max_cpu_loras: int = -1 + max_cpu_loras: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -211,24 +211,29 @@ def add_cli_args( # LoRA related configs parser.add_argument('--enable-lora', action='store_true', - help='enable lora adapters') + help='If True, enable handling of LoRA adapters.') parser.add_argument('--max-loras', type=int, default=EngineArgs.max_loras, - help='max number of LoRAs in a single batch') + help='Max number of LoRAs in a single batch.') parser.add_argument('--max-lora-rank', type=int, default=EngineArgs.max_lora_rank, - help='max LoRA rank') - parser.add_argument('--lora-extra-vocab-size', - type=int, - 
default=EngineArgs.lora_extra_vocab_size, - help='LoRA extra vocab size') - parser.add_argument('--lora-dtype', - type=str, - default=EngineArgs.lora_dtype, - choices=['auto', 'float16', 'bfloat16', 'float32'], - help='data type for LoRA') + help='Max LoRA rank.') + parser.add_argument( + '--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help=('Maximum size of extra vocabulary that can be ' + 'present in a LoRA adapter (added to the base ' + 'model vocabulary).')) + parser.add_argument( + '--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help=('Data type for LoRA. If auto, will default to ' + 'base model dtype.')) parser.add_argument( '--max-cpu-loras', type=int, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 0bfd3276d8de6..12dea4e842b44 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -125,6 +125,9 @@ def __init__( # List of (timestamp, num_tokens) self.num_generation_tokens: List[Tuple[float, int]] = [] + def get_tokenizer_for_seq(self, sequence: Sequence): + return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + def _init_workers(self, distributed_init_method: str): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -154,7 +157,7 @@ def _init_workers(self, distributed_init_method: str): max_parallel_loading_workers, ) - def _init_tokenizer(self, **kwargs): + def _init_tokenizer(self, **tokenizer_init_kwargs): init_kwargs = dict( enable_lora=bool(self.lora_config), max_num_seqs=self.scheduler_config.max_num_seqs, @@ -162,7 +165,7 @@ def _init_tokenizer(self, **kwargs): tokenizer_mode=self.model_config.tokenizer_mode, trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) - init_kwargs.update(kwargs) + init_kwargs.update(tokenizer_init_kwargs) self.tokenizer: MultiLoRATokenizer = MultiLoRATokenizer( self.model_config.tokenizer, **init_kwargs) @@ -389,13 +392,13 @@ def _check_beam_search_early_stopping( current_worst_score = (current_worst_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer( - current_worst_seq.lora_request).eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + current_worst_seq).eos_token_id)) if early_stopping is False: highest_attainable_score = (best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer( - best_running_seq.lora_request).eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -409,8 +412,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer( - best_running_seq.lora_request).eos_token_id, + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id, seq_len=max_possible_length)) else: # Otherwise, beam search will prefer shorter sequences. 
The @@ -419,8 +422,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer( - best_running_seq.lora_request).eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -511,8 +514,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the finished sequences by their scores. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request - ).eos_token_id), + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), reverse=True) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: @@ -540,8 +542,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the running sequences by their scores. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.get_lora_tokenizer(x[0].lora_request - ).eos_token_id), + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), reverse=True) # Check if we can stop the beam search. @@ -721,7 +722,7 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( - self.tokenizer.get_lora_tokenizer(seq.lora_request), + self.get_tokenizer_for_seq(seq), all_input_ids=seq.get_token_ids(), prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, @@ -764,8 +765,7 @@ def _check_stop(self, seq: Sequence, # Check if the sequence has generated the EOS token. if ((not sampling_params.ignore_eos) and seq.get_last_token_id() - == self.tokenizer.get_lora_tokenizer( - seq.lora_request).eos_token_id): + == self.get_tokenizer_for_seq(seq).eos_token_id): seq.status = SequenceStatus.FINISHED_STOPPED return diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index f03ae78ee8751..252909c859628 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -51,64 +51,20 @@ def _apply_lora( output: (batch_size, output_dim) """ org_output = output - if x.ndim == 3: - x = x.view(x.shape[0] * x.shape[1], -1) - if output.ndim == 3: - output = output.view(output.shape[0] * output.shape[1], -1) + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) return output.view_as(org_output) -def _apply_lora_packed_2slice( - x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, torch.Tensor], - indices: torch.Tensor, - output: torch.Tensor, - output_dim: int, -): - """Applies lora to each input. - - This method applies all loras to each input. It uses the - indices vector to determine which lora yields the - correct output. An index of -1 means no lora should be - applied. This method adds the final lora results to the - output. - - This method is used for layers that are composed of 2 sublayers - (slices) packed together (eg. gate_proj + up_proj -> - gate_up_proj). - - Both slices must have the same size (output_dim), meaning the output - tensor will have size output_dim*2. 
- - Input shapes: - x: (batch_size, hidden_dim) - lora_a_stacked: 2 element tuple of (num_loras, lora_rank, hidden_dim) - lora_b_stacked: 2 element tuple of (num_loras, output_dim, lora_rank) - indices: (batch_size) - output: (batch_size, output_dim*2) - output_dim: scalar - """ - org_output = output - if x.ndim == 3: - x = x.view(x.shape[0] * x.shape[1], -1) - if output.ndim == 3: - output = output.view(output.shape[0] * output.shape[1], -1) - add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, - 1.0, 0, output_dim) - add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, - 1.0, output_dim, output_dim) - return output.view_as(org_output) - - -def _apply_lora_packed_3slice( +def _apply_lora_packed_nslice( x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], indices: torch.Tensor, output: torch.Tensor, - output_slices: Tuple[int, int], + output_slices: Tuple[int, ...], ): """Applies lora to each input. @@ -118,10 +74,8 @@ def _apply_lora_packed_3slice( applied. This method adds the final lora results to the output. - This method is used for layers that are composed of 3 sublayers - (slices) packed together (attention projection). The - first slice (Q) may have different size from the two subsequent - slices (K, V). + This method is used for layers that are composed of multiple sublayers + (slices) packed together. Input shapes: x: (batch_size, hidden_dim) @@ -129,13 +83,12 @@ def _apply_lora_packed_3slice( lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) indices: (batch_size) output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: 2 element tuple of (q_slice_size, kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), where n is number of slices """ org_output = output - if x.ndim == 3: - x = x.view(x.shape[0] * x.shape[1], -1) - if output.ndim == 3: - output = output.view(output.shape[0] * output.shape[1], -1) + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, 1.0, 0, output_slices[0]) add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, @@ -147,20 +100,17 @@ def _apply_lora_packed_3slice( @dataclass class LoRAMapping: + # Per every token in input_ids: index_mapping: Tuple[int, ...] + # Per sampled token: prompt_mapping: Tuple[int, ...] - def __eq__(self, __value: object) -> bool: - return (isinstance(__value, self.__class__) - and self.prompt_mapping == __value.prompt_mapping - and self.index_mapping == __value.index_mapping) - def __post_init__(self): self.index_mapping = tuple(self.index_mapping) self.prompt_mapping = tuple(self.prompt_mapping) -class LoRALayer(nn.Module): +class BaseLayerWithLoRA(nn.Module): def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, model_config: PretrainedConfig) -> None: @@ -193,7 +143,7 @@ def set_mapping( ... 
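For reference, the standalone sketch below is not part of the patch: it spells out, in plain PyTorch, the arithmetic that _apply_lora_packed_nslice delegates to the Punica add_lora_slice kernel. Each packed slice has its own (lora_a, lora_b) pair, every token's delta is accumulated into that slice's column range of the output, and an index of -1 means the token has no active LoRA. Shapes are simplified relative to the real stacked buffers and all names are illustrative.

import torch

def packed_nslice_reference(x, lora_a_stacked, lora_b_stacked, indices,
                            output, output_slices):
    # x: (num_tokens, hidden), indices: (num_tokens,)
    # lora_a_stacked[s]: (num_loras, rank, hidden)
    # lora_b_stacked[s]: (num_loras, slice_size, rank)
    # scaling is assumed to already be folded into lora_b (see optimize()).
    offset = 0
    for lora_a, lora_b, size in zip(lora_a_stacked, lora_b_stacked,
                                    output_slices):
        for row, idx in enumerate(indices.tolist()):
            if idx < 0:
                continue  # -1 => no LoRA applied for this token
            delta = x[row] @ lora_a[idx].t() @ lora_b[idx].t()
            output[row, offset:offset + size] += delta
        offset += size
    return output

# toy usage: two LoRAs, two equally sized slices packed into one output
x = torch.randn(4, 16)
a = (torch.randn(2, 8, 16), torch.randn(2, 8, 16))
b = (torch.randn(2, 32, 8), torch.randn(2, 32, 8))
out = torch.zeros(4, 64)
packed_nslice_reference(x, a, b, torch.tensor([0, 1, -1, 0]), out, (32, 32))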
-class LoRAVocabParallelEmbedding(LoRALayer): +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: VocabParallelEmbedding) -> None: super().__init__() @@ -327,7 +277,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return full_output.view_as(full_output_org) -class LoRAColumnParallelLinear(LoRALayer): +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() @@ -432,7 +382,7 @@ def linear_weights(self): return self.base_layer.linear_weights -class LoRAMergedColumnParallelLinear2Slice(LoRAColumnParallelLinear): +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): """ColumnParallelLinear layer that is composed of 2 sublayers (slices) packed together (eg. gate_proj + up_proj -> gate_up_proj). @@ -523,18 +473,18 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( self.base_layer.linear_weights, x, bias) - _apply_lora_packed_2slice( + _apply_lora_packed_nslice( x, self.lora_a_stacked, self.lora_b_stacked, self.indices[:self.indices_len[0]], output, - self.output_dim, + (self.output_dim, ), ) return output -class LoRAQKVParallelLinear(LoRAColumnParallelLinear): +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): """ColumnParallelLinear layer that is composed of 3 sublayers (slices) packed together in qkv proj fashion (q_proj + k_proj + v_proj -> qkv_proj). @@ -687,7 +637,7 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.linear_method.apply_weights( self.base_layer.linear_weights, x, bias) - _apply_lora_packed_3slice( + _apply_lora_packed_nslice( x, self.lora_a_stacked, self.lora_b_stacked, @@ -698,7 +648,7 @@ def apply_weights(self, x: torch.Tensor, return output -class LoRARowParallelLinear(LoRALayer): +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: RowParallelLinear) -> None: super().__init__() @@ -824,7 +774,7 @@ def weight(self): return self.base_layer.weight -class LoRASampler(LoRALayer): +class SamplerWithLoRA(BaseLayerWithLoRA): def __init__( self, @@ -984,16 +934,17 @@ def forward(self, *args, **kwargs): return type(self.base_layer).forward(self, *args, **kwargs) -def from_layer(layer: nn.Module, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None) -> LoRALayer: +def from_layer( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA: supported_layer_types = { - VocabParallelEmbedding: LoRAVocabParallelEmbedding, - ColumnParallelLinear: LoRAColumnParallelLinear, - QKVParallelLinear: LoRAQKVParallelLinear, - MergedColumnParallelLinear: LoRAMergedColumnParallelLinear2Slice, - RowParallelLinear: LoRARowParallelLinear, + VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLora, + MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, + RowParallelLinear: RowParallelLinearWithLoRA, } for src_layer_type, lora_layer_type in supported_layer_types.items(): if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck @@ -1009,8 +960,8 @@ def from_layer_sampler( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None, -) -> LoRASampler: - ret = LoRASampler(layer, 
lm_head.embedding_dim, lm_head.weight.dtype, - lm_head.weight.device) +) -> SamplerWithLoRA: + ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, + lm_head.weight.device) ret.create_lora_weights(max_loras, lora_config, model_config) return ret diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py index 11006c4e1a1c9..fbb228c9582d4 100644 --- a/vllm/lora/lora.py +++ b/vllm/lora/lora.py @@ -1,10 +1,11 @@ from typing import List, Optional import torch +from vllm.utils import in_wsl -class LoRA: - """A LoRA that is composed of two low rank matrixes.""" +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" def __init__( self, @@ -28,29 +29,7 @@ def __init__( else: self.scaling = scaling - @classmethod - def pack(cls, loras: List["LoRA"]) -> "PackedLoRA": - """Pack a list of LoRAs into a single LoRA. - - If LoRA is None, it signifies that the submodule does not have a LoRA. - """ - first_lora = next(lora for lora in loras if lora is not None) - for lora in loras: - if lora is None: - continue - lora.optimize() - rank = first_lora.rank - module_name = first_lora.module_name - obj = PackedLoRA( - module_name, - rank, - [lora.lora_alpha if lora is not None else None for lora in loras], - [lora.lora_a if lora is not None else None for lora in loras], - [lora.lora_b if lora is not None else None for lora in loras], - scaling=[1 if lora is not None else None for lora in loras]) - return obj - - def optimize(self) -> "LoRA": + def optimize(self) -> "LoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" if self.scaling == 1: return @@ -75,8 +54,42 @@ def extra_vocab_size(self) -> int: return self.embeddings_tensor.shape[ 0] if self.embeddings_tensor is not None else 0 + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and not in_wsl() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + embeddings_tensor=embeddings_tensor, + ) + -class PackedLoRA(LoRA): +class PackedLoRALayerWeights(LoRALayerWeights): """LoRA used for packed layers (eg. qkv_proj).""" def __init__( @@ -103,7 +116,29 @@ def __init__( lora_alpha / self.rank for lora_alpha in self.lora_alphas ] - def optimize(self) -> "PackedLoRA": + @classmethod + def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. 
+ """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[1 if lora is not None else None for lora in loras]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": """Optimize the LoRA by merging the scaling into lora_b.""" for i in range(len(self.lora_b)): if self.scaling[i] == 1 or self.lora_b[i] is None: diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 2f7c6154c7a24..df3d92aa3eef2 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -14,12 +14,14 @@ from vllm.config import LoRAConfig from vllm.utils import LRUCache, in_wsl -from vllm.lora.layers import LoRALayer, LoRAMapping, from_layer, from_layer_sampler -from vllm.lora.lora import LoRA +from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule logger = logging.getLogger(__name__) +# TODO: The mappings below should be moved to individual model classes. + PACKED_MODULES_CFG = { "qkv_proj": [ "q_proj", @@ -131,38 +133,6 @@ def get_lora_id(): return _GLOBAL_LORA_ID -def _create_dummy_lora(module_name: str, - input_dim: int, - output_dim: int, - rank: int, - dtype: torch.dtype, - device: torch.device, - embeddings_tensor_dim: Optional[int] = None) -> "LoRA": - pin_memory = str(device) == "cpu" and not in_wsl() - lora_a = torch.zeros([input_dim, rank], - dtype=dtype, - device=device, - pin_memory=pin_memory) - lora_b = torch.zeros([rank, output_dim], - dtype=dtype, - device=device, - pin_memory=pin_memory) - embeddings_tensor = torch.rand( - 10, - embeddings_tensor_dim, - dtype=dtype, - device=device, - pin_memory=pin_memory) if embeddings_tensor_dim else None - return LoRA( - module_name, - rank=rank, - lora_alpha=1, - lora_a=lora_a, - lora_b=lora_b, - embeddings_tensor=embeddings_tensor, - ) - - class LoRAModel: """A LoRA fine-tuned model.""" @@ -170,20 +140,20 @@ def __init__( self, lora_model_id: int, rank: int, - loras: Dict[str, LoRA], + loras: Dict[str, LoRALayerWeights], ) -> None: self.id = lora_model_id assert (lora_model_id > 0), f"a valid lora id should be greater than 0, got {self.id}" self.rank = rank - self.loras: Dict[str, LoRA] = loras + self.loras: Dict[str, LoRALayerWeights] = loras @property def extra_vocab_size(self) -> int: return max(lora.extra_vocab_size for lora in self.loras.values()) if self.loras else 0 - def get_lora(self, module_name: str) -> Optional[LoRA]: + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: """Get LoRA for a given module by name""" return self.loras.get(module_name, None) @@ -202,7 +172,7 @@ def from_lora_tensors( ) -> "LoRAModel": """Create a LoRAModel from a dictionary of tensors.""" pin_memory = str(device) == "cpu" and not in_wsl() - loras: Dict[str, LoRA] = {} + loras: Dict[str, LoRALayerWeights] = {} for tensor_name, tensor in tensors.items(): module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) if module_name not in loras: @@ -218,8 +188,9 @@ def from_lora_tensors( if pin_memory: lora_embeddings_tensor = ( lora_embeddings_tensor.pin_memory()) - loras[module_name] = 
LoRA(module_name, rank, lora_alpha, None, - None, lora_embeddings_tensor) + loras[module_name] = LoRALayerWeights(module_name, rank, + lora_alpha, None, None, + lora_embeddings_tensor) if is_lora_a: loras[module_name].lora_a = tensor.to(device=device, dtype=dtype).t() @@ -241,7 +212,7 @@ def from_lora_tensors( loras[module_name].lora_b = loras[ module_name].lora_b.pin_memory() - for _, lora in loras.items(): + for lora in loras.values(): lora.optimize() return cls(lora_model_id, rank, loras) @@ -343,6 +314,9 @@ def __init__( dtype=torch.long, device="cuda") self.offsets = [] + # 4 is the number of indicies tensors defined above + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices self.indices_len = [None] * 4 self.model: nn.Module = model @@ -352,8 +326,9 @@ def __init__( self.lora_target_modules = copy.deepcopy(lora_target_modules) self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping) self.packed_modules: Dict[str, List[str]] = {} - self.modules: Dict[str, "LoRALayer"] = {} + self.modules: Dict[str, "BaseLayerWithLoRA"] = {} self._registered_loras: Dict[int, LoRAModel] = {} + # Dict instead of a Set for compatibility with LRUCache. self._active_loras: Dict[int, None] = {} self._last_mapping = None self._create_lora_modules() @@ -374,6 +349,7 @@ def activate_lora( self, lora_id: int, ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" if lora_id in self._active_loras: return False first_free_slot = next( @@ -405,6 +381,7 @@ def _deactivate_lora(self, lora_id: int): pass def deactivate_lora(self, lora_id: int) -> bool: + """Remove a LoRA from a GPU buffer.""" if lora_id in self._active_loras: self._deactivate_lora(lora_id) self._active_loras.pop(lora_id) @@ -412,7 +389,7 @@ def deactivate_lora(self, lora_id: int) -> bool: return False def add_lora(self, lora: LoRAModel) -> bool: - """Add a LoRAModel to the manager.""" + """Add a LoRAModel to the manager CPU cache.""" if lora.id not in self._registered_loras: if len(self._registered_loras) >= self.capacity: raise RuntimeError("No free LoRA slots.") @@ -422,13 +399,13 @@ def add_lora(self, lora: LoRAModel) -> bool: return False def remove_lora(self, lora_id: int) -> bool: - """Remove a LoRAModel from the manager.""" + """Remove a LoRAModel from the manager CPU cache.""" # TODO: should we check active lora? 
self.deactivate_lora(lora_id) return bool(self._registered_loras.pop(lora_id, None)) # TODO see if this can be vectorized - def convert_mapping(self, mapping: LoRAMapping) -> None: + def _set_lora_mapping(self, mapping: LoRAMapping) -> None: (base_indices, sampler_indices, sampler_indices_padded, embeddings_indices, indices_len) = convert_mapping(mapping, self.lora_index_to_id, @@ -444,9 +421,9 @@ def convert_mapping(self, mapping: LoRAMapping) -> None: # Maintain the reference self.indices_len[:] = indices_len - def set_row_lora_mapping(self, lora_mapping: LoRAMapping) -> None: + def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: if self._last_mapping != lora_mapping: - self.convert_mapping(lora_mapping) + self._set_lora_mapping(lora_mapping) self._last_mapping = lora_mapping def list_loras(self) -> Dict[int, LoRAModel]: @@ -484,8 +461,8 @@ def _create_lora_modules(self): self.sampler_indices_padded, self.embeddings_indices, self.indices_len) - def register_module(self, module_name: str, module: "LoRALayer"): - assert isinstance(module, LoRALayer) + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) self.modules[module_name] = module def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: @@ -493,7 +470,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: model = LoRAModel(lora_id, rank, {}) for module_name, module in self.model.named_modules(): if not self._match_target_modules(module_name) or not isinstance( - module, LoRALayer): + module, BaseLayerWithLoRA): continue parts = module_name.split(".") if module_name not in self.packed_modules: @@ -509,7 +486,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: hasattr(module.base_layer, "embedding_dim") else module.base_layer.weight.shape[1]) - lora = _create_dummy_lora( + lora = LoRALayerWeights.create_dummy_lora_weights( module_name, input_dim, output_dim, @@ -518,7 +495,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: "cpu", embeddings_tensor_dim=embeddings_tensor_dim) else: - lora = _create_dummy_lora( + lora = LoRALayerWeights.create_dummy_lora_weights( module_name, module.lora_a_stacked.shape[-1], module.lora_b_stacked.shape[-2], @@ -532,7 +509,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: replacements = self.packed_modules_mapping[parts[-1]] subloras = [] for i, r in enumerate(replacements): - lora = _create_dummy_lora( + lora = LoRALayerWeights.create_dummy_lora_weights( module_name + "." 
+ r, module.lora_a_stacked[i].shape[-1], module.lora_b_stacked[i].shape[-2], @@ -542,7 +519,7 @@ def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: ) lora.optimize() subloras.append(lora) - lora = LoRA.pack(subloras) + lora = PackedLoRALayerWeights.pack(subloras) model.loras[module_name] = lora return model @@ -579,7 +556,8 @@ def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: if replacement_loras[i]: continue replacement_loras[i] = None - lora_model.loras[module_name] = LoRA.pack(replacement_loras) + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) class LoRALRUCache(LRUCache): @@ -647,16 +625,15 @@ def remove_oldest_lora(self) -> bool: return False -def create_lora_adapter( +def create_lora_manager( model: nn.Module, max_num_seqs: int, max_num_batched_tokens: int, vocab_size: int, - lora_config:LoRAConfig, - target_modules: Union[str, - List[str]] = TARGET_MODULES_QKV, - lora_manager_cls:Type[LoRAModelManager] = LoRAModelManager, **kwargs)\ - -> LoRAModelManager: + lora_config: LoRAConfig, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: """Create a LoRA adapter for a given model.""" if not getattr(model, "supports_lora", False): raise ValueError(f"Model {type(model)} is not supported for LoRA.") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py index 363b7770be178..a507c08588dad 100644 --- a/vllm/lora/worker_manager.py +++ b/vllm/lora/worker_manager.py @@ -5,7 +5,7 @@ import torch from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager, - LRUCacheLoRAModelManager, create_lora_adapter) + LRUCacheLoRAModelManager, create_lora_manager) from vllm.lora.request import LoRARequest from vllm.lora.layers import LoRAMapping from vllm.config import LoRAConfig @@ -13,7 +13,7 @@ logger = logging.getLogger(__name__) -class AbstractWorkerLoRAManager(ABC): +class WorkerLoRAManager(ABC): """Abstract class for managing LoRA models on the worker side.""" def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, @@ -30,7 +30,7 @@ def is_enabled(self) -> bool: ... @abstractmethod - def create_lora_adapter( + def create_lora_manager( self, model: torch.nn.Module, target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, @@ -38,8 +38,8 @@ def create_lora_adapter( ... @abstractmethod - def apply_loras(self, lora_requests: List[LoRARequest], - lora_mapping: LoRAMapping) -> None: + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: ... @abstractmethod @@ -63,41 +63,7 @@ def list_loras(self) -> Set[int]: ... 
-class DisabledWorkerLoRAManager(AbstractWorkerLoRAManager): - """WorkerLoRAManager that does nothing.""" - - @property - def is_enabled(self) -> bool: - return False - - def create_lora_adapter( - self, - model: torch.nn.Module, - target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, - ) -> Any: - return model - - def apply_loras(self, lora_requests: List[LoRARequest], - lora_mapping: LoRAMapping) -> None: - return - - def add_lora(self, lora_request: LoRARequest) -> bool: - return False - - def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: - return False - - def remove_lora(self, lora_id: int) -> bool: - return False - - def remove_all_loras(self) -> bool: - return - - def list_loras(self) -> Set[int]: - return set() - - -class WorkerLoRAManager(AbstractWorkerLoRAManager): +class WorkerLoRAManager(WorkerLoRAManager): """WorkerLoRAManager that manages LoRA models on the worker side. Every request, the requested LoRAs will be loaded (unless they are already @@ -123,12 +89,12 @@ def __init__( def is_enabled(self) -> bool: return True - def create_lora_adapter( + def create_lora_manager( self, model: torch.nn.Module, target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, ) -> Any: - lora_manager = create_lora_adapter( + lora_manager = create_lora_manager( model, max_num_seqs=self.max_num_seqs, max_num_batched_tokens=self.max_num_batched_tokens, @@ -140,10 +106,10 @@ def create_lora_adapter( self._lora_manager: LoRAModelManager = lora_manager return lora_manager.model - def apply_loras(self, lora_requests: List[LoRARequest], - lora_mapping: LoRAMapping) -> None: + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: self._apply_loras(lora_requests) - self._lora_manager.set_row_lora_mapping(lora_mapping) + self._lora_manager.set_lora_mapping(lora_mapping) def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: loras_that_exist = self.list_loras() @@ -226,12 +192,12 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager): _lora_manager_cls: Type[ LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager - def create_lora_adapter( + def create_lora_manager( self, model: torch.nn.Module, target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, ) -> Any: - lora_manager = create_lora_adapter( + lora_manager = create_lora_manager( model, target_modules=target_modules, lora_manager_cls=self._lora_manager_cls, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 9903cb540d438..0f1125e5c8e3e 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -63,7 +63,7 @@ def get_model(model_config: ModelConfig, # Create a model instance. # The weights will be initialized as empty tensors. 
with torch.device("cuda"): - if getattr(model_class, "supports_lora", True): + if getattr(model_class, "supports_lora", False): model = model_class(model_config.hf_config, linear_method, lora_config) elif lora_config: diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index e20da8b5a2ad5..3552c5d665668 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -10,10 +10,7 @@ from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata -from vllm.lora.worker_manager import ( - DisabledWorkerLoRAManager, - LRUCacheWorkerLoRAManager, -) +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager from vllm.lora.layers import LoRAMapping from vllm.lora.request import LoRARequest from vllm.utils import in_wsl @@ -77,12 +74,7 @@ def load_model(self) -> None: self.scheduler_config.max_num_seqs, self.scheduler_config.max_num_batched_tokens, vocab_size, self.lora_config, self.device) - self.model = self.lora_manager.create_lora_adapter(self.model) - else: - self.lora_manager = DisabledWorkerLoRAManager( - self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, vocab_size, - self.lora_config, self.device) + self.model = self.lora_manager.create_lora_manager(self.model) def set_block_size(self, block_size: int) -> None: self.block_size = block_size @@ -409,7 +401,7 @@ def execute_model( flat_lora_index_mapping, lora_prompt_mapping, ) - self.apply_loras(lora_requests, lora_mapping) + self.set_active_loras(lora_requests, lora_mapping) # Execute the model. if input_metadata.use_cuda_graph: @@ -492,9 +484,9 @@ def profile_run(self) -> None: def remove_all_loras(self) -> bool: return self.lora_manager.remove_all_loras() - def apply_loras(self, lora_requests: List[LoRARequest], - lora_mapping: LoRAMapping) -> None: - self.lora_manager.apply_loras(lora_requests, lora_mapping) + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self.lora_manager.set_active_loras(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: return self.lora_manager.add_lora(lora_request) diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 49d7fdbb32c71..f1cce3f83527f 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -49,7 +49,6 @@ def __init__( self.cache_engine = None self.cache_events = None self.gpu_cache = None - self.lora_manager = None def init_model(self) -> None: # torch.distributed.all_reduce does not free the input tensor until From 6b2e6a51ec5d3ceaec66c8c62dd081148d65426a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 22 Dec 2023 16:06:15 -0800 Subject: [PATCH 23/35] Fixes --- vllm/lora/layers.py | 13 +++++++------ vllm/worker/model_runner.py | 10 +++++++++- 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 252909c859628..8f16c1ecb330e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -89,12 +89,13 @@ def _apply_lora_packed_nslice( x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) - add_lora_slice(output, x, lora_a_stacked[0], lora_b_stacked[0], indices, 0, - 1.0, 0, output_slices[0]) - add_lora_slice(output, x, lora_a_stacked[1], lora_b_stacked[1], indices, 0, - 1.0, output_slices[0], output_slices[1]) - add_lora_slice(output, x, lora_a_stacked[2], lora_b_stacked[2], 
indices, 0, - 1.0, output_slices[0] + output_slices[1], output_slices[1]) + output_slices = (0, ) + output_slices + for slice_idx_right in range(1, len(output_slices)): + slice_idx_left = slice_idx_right - 1 + add_lora_slice(output, x, lora_a_stacked[slice_idx_left], + lora_b_stacked[slice_idx_left], indices, 0, 1.0, + output_slices[slice_idx_left], + output_slices[slice_idx_right]) return output.view_as(org_output) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 3552c5d665668..2ad56365c97ad 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -72,7 +72,8 @@ def load_model(self) -> None: if self.lora_config: self.lora_manager = LRUCacheWorkerLoRAManager( self.scheduler_config.max_num_seqs, - self.scheduler_config.max_num_batched_tokens, vocab_size, + self.scheduler_config.max_num_batched_tokens + + self.scheduler_config.max_paddings, vocab_size, self.lora_config, self.device) self.model = self.lora_manager.create_lora_manager(self.model) @@ -532,6 +533,13 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: use_cuda_graph=True, ) + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + graph_runner = CUDAGraphRunner(self.model) graph_runner.capture( input_tokens[:batch_size], From 4b2224eaac2d80f5eeed1b934dd1473d278e32d8 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 22 Dec 2023 17:13:49 -0800 Subject: [PATCH 24/35] Fixes --- tests/lora/test_lora.py | 8 +- tests/lora/test_lora_manager.py | 158 ++++++++++++++++---------------- vllm/engine/arg_utils.py | 4 +- vllm/lora/layers.py | 18 ++-- vllm/lora/models.py | 14 ++- vllm/worker/model_runner.py | 10 ++ vllm/worker/worker.py | 3 +- 7 files changed, 116 insertions(+), 99 deletions(-) diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py index 1b972cc53f24d..3415d36b7e341 100644 --- a/tests/lora/test_lora.py +++ b/tests/lora/test_lora.py @@ -126,7 +126,7 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: input, lora_a_stacks, lora_b_stacks, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (m // 2, )) + device="cuda"), output, (m // 2, m // 2)) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) @@ -134,7 +134,7 @@ def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: output[:] = 0 _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, torch.full((len(input), ), -1, device="cuda"), - output, (m // 2, )) + output, (m // 2, m // 2)) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() @@ -210,7 +210,7 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: input, lora_a_stacks, lora_b_stacks, torch.randint(0, lora_a_stacks[0].shape[0], (len(input), ), - device="cuda"), output, (qkv[0], qkv[1])) + device="cuda"), output, (qkv[0], qkv[1], qkv[2])) rtol, atol = TOLERANCES[dtype] assert torch.allclose(expected, output, rtol=rtol, atol=atol) @@ -218,7 +218,7 @@ def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: output[:] = 0 _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, torch.full((len(input), ), -1, device="cuda"), - output, (qkv[0], qkv[1])) + output, (qkv[0], qkv[1], qkv[2])) assert torch.allclose(torch.zeros_like(output), output) manager.reset_lora() diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py index 9c52058ff9a51..78a4a5bc5ecd2 100644 --- 
a/tests/lora/test_lora_manager.py +++ b/tests/lora/test_lora_manager.py @@ -121,40 +121,40 @@ def test_lora_model_manager(dist_init, dummy_model): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), lora_target_modules=["dense1", "dense2", "lm_head"]) - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) assert manager.add_lora(model_lora1) assert manager.activate_lora(1) - assert manager.lora_id_to_index[0] == 1 + assert manager.lora_index_to_id[0] == 1 assert not manager.add_lora(model_lora1) assert not manager.activate_lora(1) assert manager.add_lora(model_lora2) assert manager.activate_lora(2) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 assert not manager.add_lora(model_lora2) assert not manager.activate_lora(2) assert manager.add_lora(model_lora3) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 with pytest.raises(ValueError): assert manager.activate_lora(3) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 assert manager.remove_lora(model_lora2.id) - assert manager.lora_id_to_index[1] is None + assert manager.lora_index_to_id[1] is None assert not manager.remove_lora(model_lora2.id) assert manager.remove_lora(model_lora1.id) assert not manager.remove_lora(model_lora1.id) assert manager.add_lora(model_lora1) - assert manager.lora_id_to_index[0] is None - assert manager.lora_id_to_index[1] is None + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] is None assert manager.add_lora(model_lora2) assert manager.activate_lora(3) - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] is None + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] is None assert manager.activate_lora(2) - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 def test_lora_lru_cache_model_manager(dist_init, dummy_model): @@ -169,43 +169,43 @@ def test_lora_lru_cache_model_manager(dist_init, dummy_model): 2, LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), lora_target_modules=["dense1", "dense2", "lm_head"]) - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) assert manager.add_lora(model_lora1) assert manager.activate_lora(1) - assert manager.lora_id_to_index[0] == 1 + assert manager.lora_index_to_id[0] == 1 assert not manager.add_lora(model_lora1) assert not manager.activate_lora(1) assert manager.add_lora(model_lora2) assert manager.activate_lora(2) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 assert not manager.add_lora(model_lora2) assert not manager.activate_lora(2) assert manager.add_lora(model_lora3) - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 assert manager.activate_lora(3) - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 2 + assert 
manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 assert manager.remove_lora(model_lora2.id) - assert manager.lora_id_to_index[1] is None + assert manager.lora_index_to_id[1] is None assert not manager.remove_lora(model_lora2.id) assert manager.remove_lora(model_lora1.id) assert not manager.remove_lora(model_lora1.id) assert manager.add_lora(model_lora1) assert manager.activate_lora(1) - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 1 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 1 assert manager.add_lora(model_lora2) assert manager.deactivate_lora(3) - assert manager.lora_id_to_index[0] is None - assert manager.lora_id_to_index[1] == 1 + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 1 assert manager.activate_lora(2) - assert manager.lora_id_to_index[0] == 2 - assert manager.lora_id_to_index[1] == 1 + assert manager.lora_index_to_id[0] == 2 + assert manager.lora_index_to_id[1] == 1 assert manager.activate_lora(3) - assert manager.lora_id_to_index[0] == 2 - assert manager.lora_id_to_index[1] == 3 + assert manager.lora_index_to_id[0] == 2 + assert manager.lora_index_to_id[1] == 3 def test_lru_lora_model_manager(dist_init, dummy_model): @@ -221,7 +221,7 @@ def test_lru_lora_model_manager(dist_init, dummy_model): LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), ["dense1", "dense2", "lm_head"]) - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) # Add up to capacity assert manager.add_lora(model_lora1) @@ -230,8 +230,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_lora(2) assert set(manager.list_loras()) == {1, 2} - assert manager.lora_id_to_index[0] == 1 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 # Add over capacity assert manager.add_lora(model_lora3) @@ -240,8 +240,8 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_lora(4) assert set(manager.list_loras()) == {3, 4} - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 4 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 4 # Add 3 again to move it to the top and then add 2 # should return false since it's in already @@ -251,16 +251,16 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_lora(2) assert set(manager.list_loras()) == {3, 2} - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 # Remove manually assert manager.remove_lora(3) assert not manager.remove_lora(3) assert set(manager.list_loras()) == {2} - assert manager.lora_id_to_index[0] is None - assert manager.lora_id_to_index[1] == 2 + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 2 assert manager.add_lora(model_lora3) assert manager.activate_lora(3) @@ -268,21 +268,21 @@ def test_lru_lora_model_manager(dist_init, dummy_model): assert manager.activate_lora(4) assert set(manager.list_loras()) == {3, 4} - assert manager.lora_id_to_index[0] == 3 - assert manager.lora_id_to_index[1] == 4 + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 4 assert manager.remove_oldest_lora() assert set(manager.list_loras()) == {4} - assert manager.lora_id_to_index[0] is None - 
assert manager.lora_id_to_index[1] == 4 + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 4 assert manager.remove_oldest_lora() assert set(manager.list_loras()) == set() - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) assert not manager.remove_oldest_lora() assert set(manager.list_loras()) == set() - assert all(x is None for x in manager.lora_id_to_index) + assert all(x is None for x in manager.lora_index_to_id) def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, @@ -299,8 +299,8 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("2", 2, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -308,10 +308,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("4", 4, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2, 3, 4} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 3 - assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -319,10 +319,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("5", 5, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 - assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -330,10 +330,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("1", 1, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2, 4, 5} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 - assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 4 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 
worker_lora_manager.set_active_loras([ LoRARequest("6", 6, sql_lora_files), @@ -341,10 +341,10 @@ def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("8", 8, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 6, 7, 8} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 7 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 8 - assert worker_lora_manager._lora_manager.lora_id_to_index[3] == 6 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6 # Over capacity with pytest.raises(RuntimeError): @@ -372,8 +372,8 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("2", 2, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -381,9 +381,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("4", 4, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 3, 4} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 3 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 4 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -391,9 +391,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("5", 5, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1, 2, 5} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 2 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 worker_lora_manager.set_active_loras([ LoRARequest("1", 1, sql_lora_files), @@ -401,9 +401,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("1", 1, sql_lora_files) ], mapping) assert worker_lora_manager.list_loras() == {1} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 1 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] is None - assert worker_lora_manager._lora_manager.lora_id_to_index[2] is None + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None + assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None worker_lora_manager.set_active_loras([ LoRARequest("6", 6, sql_lora_files), @@ -411,9 +411,9 @@ def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, LoRARequest("8", 8, sql_lora_files) ], mapping) assert 
worker_lora_manager.list_loras() == {6, 7, 8} - assert worker_lora_manager._lora_manager.lora_id_to_index[0] == 8 - assert worker_lora_manager._lora_manager.lora_id_to_index[1] == 6 - assert worker_lora_manager._lora_manager.lora_id_to_index[2] == 7 + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7 # Over capacity with pytest.raises(RuntimeError): diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 62e5aa5257914..090fa95bcac02 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -279,8 +279,8 @@ def create_engine_configs( max_loras=self.max_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, - max_cpu_loras=self.max_cpu_loras - if self.max_cpu_loras > 0 else None) if self.enable_lora else None + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None return model_config, cache_config, parallel_config, scheduler_config, lora_config diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 8f16c1ecb330e..5c26ce37bbf8d 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -89,13 +89,12 @@ def _apply_lora_packed_nslice( x = x.view(-1, x.shape[-1]) output = output.view(-1, output.shape[-1]) indices = indices.view(-1) - output_slices = (0, ) + output_slices - for slice_idx_right in range(1, len(output_slices)): - slice_idx_left = slice_idx_right - 1 - add_lora_slice(output, x, lora_a_stacked[slice_idx_left], - lora_b_stacked[slice_idx_left], indices, 0, 1.0, - output_slices[slice_idx_left], - output_slices[slice_idx_right]) + offset_left = 0 + for slice_idx in range(len(output_slices)): + add_lora_slice(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] return output.view_as(org_output) @@ -480,7 +479,7 @@ def apply_weights(self, x: torch.Tensor, self.lora_b_stacked, self.indices[:self.indices_len[0]], output, - (self.output_dim, ), + (self.output_dim, self.output_dim), ) return output @@ -563,7 +562,8 @@ def create_lora_weights( device=self.base_layer.weight.device, )) - self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size) + self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, + self.kv_proj_shard_size) self.packed_indices: Optional[torch.Tensor] = None self.standard_indices: Optional[torch.Tensor] = None self.indices_len: Optional[List[int]] = None diff --git a/vllm/lora/models.py b/vllm/lora/models.py index df3d92aa3eef2..6c78c4a2c7771 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -388,13 +388,16 @@ def deactivate_lora(self, lora_id: int) -> bool: return True return False + def _add_lora(self, lora: LoRAModel) -> bool: + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager CPU cache.""" if lora.id not in self._registered_loras: if len(self._registered_loras) >= self.capacity: raise RuntimeError("No free LoRA slots.") - self._create_merged_loras_inplace(lora) - self._registered_loras[lora.id] = lora + self._add_lora(lora) return True return False @@ -600,10 +603,13 @@ def list_loras(self) -> Dict[int, LoRAModel]: def add_lora(self, lora: LoRAModel) -> bool: """Add a LoRAModel to the manager.""" - was_added = 
super().add_lora(lora) - if not was_added: + if lora.id not in self._registered_loras: + self._add_lora(lora) + was_added = True + else: # We always touch to update the LRU cache order self._registered_loras.touch(lora.id) + was_added = False return was_added def activate_lora( diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 2ad56365c97ad..efbc09be1830d 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -483,19 +483,29 @@ def profile_run(self) -> None: return def remove_all_loras(self) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") return self.lora_manager.remove_all_loras() def set_active_loras(self, lora_requests: List[LoRARequest], lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") self.lora_manager.set_active_loras(lora_requests, lora_mapping) def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") return self.lora_manager.add_lora(lora_request) def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") return self.lora_manager.remove_lora(lora_id) def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") return self.lora_manager.list_loras() @torch.inference_mode() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index f1cce3f83527f..bb8e7fd6cf86e 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -111,7 +111,8 @@ def profile_num_available_blocks( num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) - self.model_runner.remove_all_loras() + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() gc.collect() torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks From 891070257c145b506a20666a3cb70afcf674d4ca Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 4 Jan 2024 16:34:11 -0800 Subject: [PATCH 25/35] Add 5632 to kernel sizes --- csrc/punica/bgmv/bgmv_config.h | 1 + tests/lora/test_punica.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index ced0397dab216..ebf638f104c3f 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -27,6 +27,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 5120) \ f(in_T, out_T, W_T, narrow, 5504) \ + f(in_T, out_T, W_T, narrow, 5632) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 8192) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index f603b06cdb565..903814faa5dc7 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -44,8 +44,8 @@ def _lora_ref_impl( H1 = H2 = [ 128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120, - 5504, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000, 32256, - 32512, 32768, 33024 + 5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000, + 32256, 32512, 32768, 33024 ] SEED = [0xabcdabcd987] From affbc069110c1d8f74a42d78638588afeeed41fd Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 18 Jan 2024 13:38:34 -0800 Subject: [PATCH 26/35] Prefix support --- vllm/engine/llm_engine.py | 3 ++- vllm/prefix.py | 5 +++-- 2 files changed, 5 
insertions(+), 3 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 01dfd61d88825..84303ac66433c 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -438,7 +438,8 @@ def add_request( # Check whether the input specifies prefix prefix = self.scheduler.prefix_pool.add_or_get_prefix( - prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None + prompt_token_ids[:prefix_pos], + lora_request.lora_int_id) if prefix_pos is not None else None # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, diff --git a/vllm/prefix.py b/vllm/prefix.py index 06b5b32a38fcc..5b6e8e4b92be6 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -74,13 +74,14 @@ def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: new_length = len(token_ids) // self.block_size * self.block_size return tuple(token_ids[:new_length]) - def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: + def add_or_get_prefix(self, token_ids: Sequence[int], + lora_int_id: int) -> Optional[Prefix]: token_ids = self._truncate_token_ids(token_ids) if len(token_ids) == 0: # Prefix is empty. return None prefix = Prefix(token_ids, self.block_size) - prefix_hash = hash(prefix) + prefix_hash = hash((prefix, lora_int_id)) if prefix_hash not in self.prefixes: self.prefixes[prefix_hash] = prefix return self.prefixes[prefix_hash] From 444d06062f5c9e46bb7a39a33c8e942d4a15bcc7 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 18 Jan 2024 13:50:26 -0800 Subject: [PATCH 27/35] Address feedback from code review --- .buildkite/test-pipeline.yaml | 3 + examples/offline_inference.py | 2 +- tests/lora/test_tokenizer.py | 6 +- vllm/engine/llm_engine.py | 4 +- vllm/lora/layers.py | 100 +++++++++++++------------- vllm/lora/request.py | 7 +- vllm/model_executor/models/llama.py | 4 +- vllm/model_executor/models/mistral.py | 4 +- vllm/transformers_utils/tokenizer.py | 3 +- vllm/worker/model_runner.py | 2 +- 10 files changed, 72 insertions(+), 63 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index adf2bb2b43c1a..65ac2f74fb8dc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -41,6 +41,9 @@ steps: - label: Worker Test command: pytest -v -s worker +- label: LoRA Test + command: pytest -v -s lora + - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" commands: diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..fd26306773f67 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", enable_lora=True) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py index af0fc41f3fa45..6c4c91fce8127 100644 --- a/tests/lora/test_tokenizer.py +++ b/tests/lora/test_tokenizer.py @@ -2,13 +2,13 @@ from transformers import AutoTokenizer, PreTrainedTokenizerBase from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizer import MultiLoRATokenizer, get_lora_tokenizer +from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer @pytest.mark.asyncio async def test_transformers_tokenizer(): reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") - tokenizer = MultiLoRATokenizer( + tokenizer = TokenizerGroup( tokenizer_id="gpt2", enable_lora=False, max_num_seqs=1, @@ -29,7 +29,7 @@ async def test_transformers_tokenizer(): @pytest.mark.asyncio async def test_transformers_tokenizer_lora(sql_lora_files): reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) - tokenizer = MultiLoRATokenizer( + tokenizer = TokenizerGroup( tokenizer_id="gpt2", enable_lora=True, max_num_seqs=1, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 84303ac66433c..025fb2f4391cf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -18,7 +18,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - MultiLoRATokenizer) + TokenizerGroup) from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port if ray: @@ -156,7 +156,7 @@ def _init_tokenizer(self, **tokenizer_init_kwargs): trust_remote_code=self.model_config.trust_remote_code, revision=self.model_config.tokenizer_revision) init_kwargs.update(tokenizer_init_kwargs) - self.tokenizer: MultiLoRATokenizer = MultiLoRATokenizer( + self.tokenizer: TokenizerGroup = TokenizerGroup( self.model_config.tokenizer, **init_kwargs) def _init_workers_ray(self, placement_group: "PlacementGroup", diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index e711266cbe4d2..e1aac20b038b4 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -514,54 +514,58 @@ def create_lora_weights( self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas # q, k, v - self.lora_a_stacked = (torch.zeros( - max_loras, - 1, - lora_config.max_lora_rank, - self.base_layer.weight.shape[1], - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ), - torch.zeros( - max_loras, - 1, - lora_config.max_lora_rank, - self.base_layer.weight.shape[1], - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ), - torch.zeros( - max_loras, - 1, - lora_config.max_lora_rank, - self.base_layer.weight.shape[1], - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - )) - self.lora_b_stacked = (torch.zeros( - max_loras, - 1, - self.q_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - ), - torch.zeros( - max_loras, - 1, - self.kv_proj_shard_size, - lora_config.max_lora_rank, - dtype=lora_config.lora_dtype, - device=self.base_layer.weight.device, - )) + self.lora_a_stacked = ( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + 
device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + ) + self.lora_b_stacked = ( + torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + ) self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, self.kv_proj_shard_size) diff --git a/vllm/lora/request.py b/vllm/lora/request.py index 5d45f8a0f396d..bbbf4880ab81b 100644 --- a/vllm/lora/request.py +++ b/vllm/lora/request.py @@ -11,11 +11,11 @@ class LoRARequest: instead provide another layer of abstraction to prevent users from accessing unauthorized LoRA adapters. - lora_id and lora_int_id must be globally unique for a given adapter. + lora_int_id must be globally unique for a given adapter. This is currently not enforced in vLLM. """ - lora_id: str + lora_name: str lora_int_id: int lora_local_path: str @@ -25,7 +25,8 @@ def __post_init__(self): f"lora_int_id must be > 0, got {self.lora_int_id}") def __eq__(self, value: object) -> bool: - return isinstance(value, LoRARequest) and self.lora_id == value.lora_id + return isinstance( + value, LoRARequest) and self.lora_int_id == value.lora_int_id def __hash__(self) -> int: return self.lora_int_id diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c923804c57611..e5a1abebf1420 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -288,9 +288,9 @@ def __init__( unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - padding_size=64 if not lora_config else lora_config.lora_vocab_padding_size, ) self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 80e609f6e4a52..01cde67844122 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -38,7 +38,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) from 
vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -286,9 +286,9 @@ def __init__( unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE # We need bigger padding if using lora for kernel # compatibility - padding_size=64 if not lora_config else lora_config.lora_vocab_padding_size, ) self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 532c7a4e6c1dc..6edc225cdfc80 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -88,7 +88,8 @@ def get_lora_tokenizer(lora_request: LoRARequest, *args, get_lora_tokenizer_async = make_async(get_lora_tokenizer) -class MultiLoRATokenizer: +class TokenizerGroup: + """A group of tokenizers that can be used for LoRA adapters.""" def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, max_input_length: Optional[int], **tokenizer_config): diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c275cb3ade391..25ed6a579d0f6 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -635,7 +635,7 @@ def profile_run(self) -> None: for idx in range(self.lora_config.max_loras): lora_id = idx + 1 dummy_lora_request = LoRARequest( - lora_id=f"warmup_{lora_id}", + lora_name=f"warmup_{lora_id}", lora_int_id=lora_id, lora_local_path="/not/a/real/path", ) From 077dd2155d8f1282c694e87e9ac250065c25f612 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 18 Jan 2024 14:03:36 -0800 Subject: [PATCH 28/35] Fix --- vllm/worker/model_runner.py | 1 + 1 file changed, 1 insertion(+) diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 25ed6a579d0f6..0ad89671f9b85 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -443,6 +443,7 @@ def prepare_input_tensors( lora_index_mapping, lora_prompt_mapping, lora_requests) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] + subquery_lens = None sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) From 8a1e03b09107de19c054bf8dfda18025d2ba532c Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 18 Jan 2024 14:09:58 -0800 Subject: [PATCH 29/35] Fix tests --- tests/lora/conftest.py | 3 +++ vllm/model_executor/parallel_utils/parallel_state.py | 6 +++++- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index d7d46563fe5dc..54ebf492ea4e3 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -1,3 +1,4 @@ +import contextlib import gc import tempfile from collections import OrderedDict @@ -23,6 +24,8 @@ def cleanup(): destroy_model_parallel() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() gc.collect() torch.cuda.empty_cache() ray.shutdown() diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index 9a5e2889381d9..58188fc9344fc 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -170,10 +170,14 @@ def get_pipeline_model_parallel_prev_rank(): def destroy_model_parallel(): - """Set the groups to none.""" + """Set the groups to none and destroy them.""" global _TENSOR_MODEL_PARALLEL_GROUP + if 
_TENSOR_MODEL_PARALLEL_GROUP: + torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP) _TENSOR_MODEL_PARALLEL_GROUP = None global _PIPELINE_MODEL_PARALLEL_GROUP + if _PIPELINE_MODEL_PARALLEL_GROUP: + torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None From f823fd91907498c37056ccbeb4e57e56f50f4bcc Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 18 Jan 2024 14:15:52 -0800 Subject: [PATCH 30/35] Fix test --- tests/lora/test_llama.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index 66d4c39e31181..0ddeebd021f29 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -3,6 +3,7 @@ import vllm from vllm.lora.request import LoRARequest +from .conftest import cleanup MODEL_PATH = "meta-llama/Llama-2-7b-hf" @@ -89,7 +90,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) del llm_tp1 - ray.shutdown() + cleanup() llm_tp2 = vllm.LLM(MODEL_PATH, enable_lora=True, @@ -99,7 +100,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) del llm_tp2 - ray.shutdown() + cleanup() assert output_tp1 == output_tp2 @@ -111,7 +112,7 @@ def test_llama_tensor_parallel_equality(sql_lora_files): output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) del llm_tp4 - ray.shutdown() + cleanup() assert output_tp1 == output_tp4 From 2fde821a763628a0cfe529d936f645e176b81e6e Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 18 Jan 2024 14:20:11 -0800 Subject: [PATCH 31/35] Fix tests --- tests/lora/conftest.py | 2 +- tests/lora/test_worker.py | 22 +++++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 54ebf492ea4e3..c1b3d04c713b5 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -140,4 +140,4 @@ def get_model_patched(model_config, lora_config=None): @pytest.fixture def llama_2_7b_model_extra_embeddings( llama_2_7b_engine_extra_embeddings) -> nn.Module: - yield llama_2_7b_engine_extra_embeddings.workers[0].model_runner.model + yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py index 126d910f53ab3..68c2c0b5fc134 100644 --- a/tests/lora/test_worker.py +++ b/tests/lora/test_worker.py @@ -12,17 +12,21 @@ @patch.dict(os.environ, {"RANK": "0"}) def test_worker_apply_lora(sql_lora_files): worker = Worker( - model_config=ModelConfig("meta-llama/Llama-2-7b-hf", - "meta-llama/Llama-2-7b-hf", - tokenizer_mode="auto", - trust_remote_code=False, - download_dir=None, - load_format="dummy", - seed=0, - dtype="float16", - revision=None), + model_config=ModelConfig( + "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + ), parallel_config=ParallelConfig(1, 1, False), scheduler_config=SchedulerConfig(32, 32, 32, 256), + local_rank=0, + rank=0, lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, max_loras=32), distributed_init_method=f"file://{tempfile.mkstemp()[1]}", From 157fa1a60b5085c859b155253406782c9af01c01 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 18 Jan 2024 16:13:34 -0800 Subject: [PATCH 32/35] Fix tests --- 
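One detail worth calling out in this patch: the prefix cache key now folds in the adapter id (falling back to 0 when no LoRA is attached), so two requests that share a prompt prefix but run different adapters never reuse each other's cached blocks. Below is a minimal sketch of that keying, with made-up names (ToyPrefixPool, add_or_get) rather than the real add_or_get_prefix in vllm/prefix.py.

from typing import Dict, Optional, Sequence, Tuple

class ToyPrefixPool:
    def __init__(self, block_size: int):
        self.block_size = block_size
        self._prefixes: Dict[int, Tuple[int, ...]] = {}

    def add_or_get(self, token_ids: Sequence[int],
                   lora_int_id: int = 0) -> Optional[Tuple[int, ...]]:
        # Only whole blocks can be cached, so truncate to a block multiple.
        length = len(token_ids) // self.block_size * self.block_size
        if length == 0:
            return None
        prefix = tuple(token_ids[:length])
        key = hash((prefix, lora_int_id))   # adapter id is part of the cache key
        return self._prefixes.setdefault(key, prefix)

pool = ToyPrefixPool(block_size=4)
tokens = [11, 12, 13, 14, 15]
pool.add_or_get(tokens, lora_int_id=3)
pool.add_or_get(tokens, lora_int_id=7)
print(len(pool._prefixes))   # 2 -- same tokens, different adapters, separate entries
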
tests/async_engine/test_async_llm_engine.py | 7 +++++++ tests/lora/test_llama.py | 5 ++++- tests/samplers/test_sampler.py | 4 ++-- tests/worker/test_model_runner.py | 2 +- vllm/engine/async_llm_engine.py | 2 ++ vllm/engine/llm_engine.py | 4 ++-- 6 files changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 174975802dc0d..abe0b5d64ab9e 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,6 +25,13 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] + async def encode_request_async( + self, + *args, + **kwargs, + ): + return [1] + def generate(self, request_id): self.request_id = request_id diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py index 0ddeebd021f29..06fbf19eea824 100644 --- a/tests/lora/test_llama.py +++ b/tests/lora/test_llama.py @@ -35,8 +35,9 @@ def do_sample(llm, lora_path: str, lora_id: int): return generated_texts -@pytest.mark.parametrize("tp_size", [1, 2, 4]) +@pytest.mark.parametrize("tp_size", [1]) def test_llama_lora(sql_lora_files, tp_size): + # Cannot use as it will initialize torch.cuda too early... # if torch.cuda.device_count() < tp_size: # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") @@ -78,7 +79,9 @@ def test_llama_lora(sql_lora_files, tp_size): print("removing lora") +@pytest.mark.skip("Requires multiple GPUs") def test_llama_tensor_parallel_equality(sql_lora_files): + # Cannot use as it will initialize torch.cuda too early... # if torch.cuda.device_count() < 4: # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index bcd0cd60bfc52..9527093805aea 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -38,7 +38,7 @@ def _prepare_test( device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(32000, fake_logits) - model_runner = ModelRunner(None, None, None) + model_runner = ModelRunner(None, None, None, None) return input_tensor, fake_logits, sampler, model_runner @@ -266,7 +266,7 @@ def test_sampler_top_k_top_p(seed: int): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(32000, fake_logits) - model_runner = ModelRunner(None, None, None) + model_runner = ModelRunner(None, None, None, None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index edbe10684741f..701ee0af7b10c 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -6,7 +6,7 @@ def test_prepare_prompt(): - model_runner = ModelRunner(None, None, None) + model_runner = ModelRunner(None, None, None, None) model_runner.set_block_size(16) batch_size = random.randint(1, 256) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 258b3e17f9e24..c7591945be243 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -227,6 +227,7 @@ async def add_request_async( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, lora_request: Optional[LoRARequest] = None, + prefix_pos: Optional[int] = None, ) -> None: if lora_request is not None and not self.lora_config: raise ValueError(f"Got lora_request {lora_request} but LoRA is " @@ -246,6 +247,7 @@ async def 
add_request_async( sampling_params=sampling_params, arrival_time=arrival_time, lora_request=lora_request, + prefix_pos=prefix_pos, ) async def _run_workers_async( diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 025fb2f4391cf..aa3383ba7cb71 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -438,8 +438,8 @@ def add_request( # Check whether the input specifies prefix prefix = self.scheduler.prefix_pool.add_or_get_prefix( - prompt_token_ids[:prefix_pos], - lora_request.lora_int_id) if prefix_pos is not None else None + prompt_token_ids[:prefix_pos], lora_request.lora_int_id + if lora_request else 0) if prefix_pos is not None else None # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, From 0a43cd8456095c2d427a2b8901c452ed246f8829 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Fri, 19 Jan 2024 01:14:15 +0100 Subject: [PATCH 33/35] Update examples/offline_inference.py --- examples/offline_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index fd26306773f67..9b758fa2479f6 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -11,7 +11,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0", enable_lora=True) +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) From 52da6cc3cbc5a55fe8028199397f262f215e5c90 Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Thu, 18 Jan 2024 23:24:01 -0800 Subject: [PATCH 34/35] Fix tests --- tests/async_engine/test_async_llm_engine.py | 4 ++++ tests/samplers/test_sampler.py | 9 +++++---- tests/worker/test_model_runner.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index abe0b5d64ab9e..1edb19c550010 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -42,6 +42,10 @@ def add_request(self, **kwargs): del kwargs # Unused self.add_request_calls += 1 + async def add_request_async(self, **kwargs): + del kwargs # Unused + self.add_request_calls += 1 + def abort_request(self, request_id): del request_id # Unused self.abort_request_calls += 1 diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 9527093805aea..962183a29fbfa 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -19,10 +19,11 @@ def __init__(self, vocab_size: int, fake_logits: torch.Tensor): self.fake_logits = fake_logits def forward(self, *args, **kwargs): - with patch("vllm.model_executor.layers.sampler._prune_hidden_states", - lambda x, y: x), patch( - "vllm.model_executor.layers.sampler._get_logits", - lambda *args, **kwargs: self.fake_logits): + with patch( + "vllm.model_executor.layers.sampler._prune_hidden_states", + lambda x, y: x), patch( + "vllm.model_executor.layers.sampler.Sampler._get_logits", + lambda *args, **kwargs: self.fake_logits): return super().forward(*args, **kwargs) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 701ee0af7b10c..5d9ad0520de13 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -33,7 +33,7 @@ def 
test_prepare_prompt(): expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens, _ = ( + input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = ( model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, From 3ac8bdfaa8ee86f49e1629db3cb8fc9674a0cf9d Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Tue, 23 Jan 2024 12:40:40 -0800 Subject: [PATCH 35/35] Fix CI --- tests/worker/spec_decode/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/worker/spec_decode/utils.py b/tests/worker/spec_decode/utils.py index 812033829394f..e0db770046ec8 100644 --- a/tests/worker/spec_decode/utils.py +++ b/tests/worker/spec_decode/utils.py @@ -83,8 +83,8 @@ def create_worker(cls: type, enforce_eager=enforce_eager, ) - (model_config, cache_config, parallel_config, - scheduler_config) = engine_args.create_engine_configs() + (model_config, cache_config, parallel_config, scheduler_config, + _) = engine_args.create_engine_configs() distributed_init_method = get_distributed_init_method( get_ip(), get_open_port())
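
A closing note on the packed-slice LoRA path that several patches in this series touch: when Q, K and V are fused into a single projection, each slice gets its own stacked (A, B) pair and the output columns are walked with a running offset, exactly the offset_left accumulation added to _apply_lora_packed_nslice earlier in the series. The snippet below is an unoptimized, pure-PyTorch rendering of that idea for readers without the punica kernels; the function name, the shapes and the absence of scaling are simplifications, not the real add_lora_slice signature.

import torch

def apply_lora_packed_nslice_ref(output, x, lora_a_stacked, lora_b_stacked,
                                 indices, output_slices):
    # output:          [num_tokens, sum(output_slices)]  (fused q/k/v projection)
    # x:               [num_tokens, hidden]
    # lora_a_stacked:  tuple of [max_loras, rank, hidden] tensors, one per slice
    # lora_b_stacked:  tuple of [max_loras, slice_size, rank] tensors, one per slice
    # indices:         [num_tokens] LoRA slot per token, -1 means "no adapter"
    offset = 0
    for a, b, size in zip(lora_a_stacked, lora_b_stacked, output_slices):
        for slot in indices.unique():
            if slot < 0:
                continue
            mask = indices == slot
            delta = x[mask] @ a[slot].t() @ b[slot].t()     # (x A^T) B^T
            output[mask, offset:offset + size] += delta
        offset += size
    return output

Keeping the per-slice sizes explicit (q_proj_shard_size, kv_proj_shard_size, kv_proj_shard_size) is what the output_slices fix earlier in this series corrects, and the per-token indices tensor is what allows tokens from different requests, each with a different adapter, to be batched through the same fused projection.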