From 2f3d535cd586bc8067207d5be8f1be25d0c75c4a Mon Sep 17 00:00:00 2001 From: Antoni Baum Date: Wed, 24 Jan 2024 00:26:37 +0100 Subject: [PATCH] [Experimental] Add multi-LoRA support (#1804) Co-authored-by: Chen Shen Co-authored-by: Shreyas Krishnaswamy Co-authored-by: Avnish Narayan --- .buildkite/test-pipeline.yaml | 3 + benchmarks/benchmark_latency.py | 10 +- csrc/punica/LICENSE | 217 +++ csrc/punica/bgmv/bgmv_all.cu | 21 + csrc/punica/bgmv/bgmv_config.h | 59 + csrc/punica/bgmv/bgmv_impl.cuh | 294 ++++ csrc/punica/bgmv/vec_dtypes.cuh | 1324 +++++++++++++++++ csrc/punica/punica_ops.cc | 563 +++++++ examples/multilora_inference.py | 117 ++ setup.py | 59 +- tests/async_engine/test_async_llm_engine.py | 11 + tests/lora/__init__.py | 0 tests/lora/conftest.py | 143 ++ tests/lora/test_layers.py | 709 +++++++++ tests/lora/test_llama.py | 144 ++ tests/lora/test_lora.py | 224 +++ tests/lora/test_lora_manager.py | 475 ++++++ tests/lora/test_punica.py | 175 +++ tests/lora/test_tokenizer.py | 69 + tests/lora/test_utils.py | 172 +++ tests/lora/test_worker.py | 61 + tests/lora/utils.py | 88 ++ tests/samplers/test_sampler.py | 13 +- tests/worker/spec_decode/utils.py | 4 +- tests/worker/test_model_runner.py | 4 +- vllm/config.py | 51 +- vllm/core/scheduler.py | 73 +- vllm/engine/arg_utils.py | 53 +- vllm/engine/async_llm_engine.py | 79 +- vllm/engine/llm_engine.py | 111 +- vllm/entrypoints/llm.py | 11 +- vllm/lora/__init__.py | 0 vllm/lora/layers.py | 975 ++++++++++++ vllm/lora/lora.py | 160 ++ vllm/lora/models.py | 654 ++++++++ vllm/lora/punica.py | 173 +++ vllm/lora/request.py | 32 + vllm/lora/utils.py | 39 + vllm/lora/worker_manager.py | 237 +++ vllm/model_executor/layers/sampler.py | 35 +- .../layers/vocab_parallel_embedding.py | 26 +- vllm/model_executor/model_loader.py | 19 +- vllm/model_executor/models/llama.py | 31 +- vllm/model_executor/models/mistral.py | 33 +- .../parallel_utils/parallel_state.py | 6 +- vllm/outputs.py | 19 +- vllm/prefix.py | 5 +- vllm/sequence.py | 22 + vllm/transformers_utils/tokenizer.py | 80 + vllm/utils.py | 90 ++ vllm/worker/model_runner.py | 161 +- vllm/worker/worker.py | 27 +- 52 files changed, 8035 insertions(+), 126 deletions(-) create mode 100644 csrc/punica/LICENSE create mode 100644 csrc/punica/bgmv/bgmv_all.cu create mode 100644 csrc/punica/bgmv/bgmv_config.h create mode 100644 csrc/punica/bgmv/bgmv_impl.cuh create mode 100644 csrc/punica/bgmv/vec_dtypes.cuh create mode 100644 csrc/punica/punica_ops.cc create mode 100644 examples/multilora_inference.py create mode 100644 tests/lora/__init__.py create mode 100644 tests/lora/conftest.py create mode 100644 tests/lora/test_layers.py create mode 100644 tests/lora/test_llama.py create mode 100644 tests/lora/test_lora.py create mode 100644 tests/lora/test_lora_manager.py create mode 100644 tests/lora/test_punica.py create mode 100644 tests/lora/test_tokenizer.py create mode 100644 tests/lora/test_utils.py create mode 100644 tests/lora/test_worker.py create mode 100644 tests/lora/utils.py create mode 100644 vllm/lora/__init__.py create mode 100644 vllm/lora/layers.py create mode 100644 vllm/lora/lora.py create mode 100644 vllm/lora/models.py create mode 100644 vllm/lora/punica.py create mode 100644 vllm/lora/request.py create mode 100644 vllm/lora/utils.py create mode 100644 vllm/lora/worker_manager.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index adf2bb2b43c1a..65ac2f74fb8dc 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -41,6 +41,9 @@ steps: - label: Worker Test command: pytest -v -s worker +- label: LoRA Test + command: pytest -v -s lora + - label: Benchmarks working_dir: "/vllm-workspace/.buildkite" commands: diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index e33d5fb2dc247..d75d690cc66d4 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -65,7 +65,9 @@ def run_to_completion(profile_dir: Optional[str] = None): if args.profile: profile_dir = args.profile_result_dir if not profile_dir: - profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" + profile_dir = Path( + "." + ) / "vllm_benchmark_result" / f"latency_result_{time.time()}" print(f"Profiling (results will be saved to '{profile_dir}')...") run_to_completion(profile_dir=args.profile_result_dir) return @@ -123,9 +125,7 @@ def run_to_completion(profile_dir: Optional[str] = None): '--profile-result-dir', type=str, default=None, - help=( - 'path to save the pytorch profiler output. Can be visualized ' - 'with ui.perfetto.dev or Tensorboard.' - )) + help=('path to save the pytorch profiler output. Can be visualized ' + 'with ui.perfetto.dev or Tensorboard.')) args = parser.parse_args() main(args) diff --git a/csrc/punica/LICENSE b/csrc/punica/LICENSE new file mode 100644 index 0000000000000..a46e2cdcadf7d --- /dev/null +++ b/csrc/punica/LICENSE @@ -0,0 +1,217 @@ +Contains code from https://github.com/punica-ai/punica + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "{}" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright {yyyy} {name of copyright owner} + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +------------------------------------------------------------------------------------ + +This product bundles various third-party components under other open source licenses. +This section summarizes those components and their licenses. See licenses/ +for text of these licenses. + + +Apache-2.0 +* third_party/nvbench (with LLVM exception) +* third_party/flashinfer + +BSD-3-Clause: +* third_party/cutlass \ No newline at end of file diff --git a/csrc/punica/bgmv/bgmv_all.cu b/csrc/punica/bgmv/bgmv_all.cu new file mode 100644 index 0000000000000..2502a67e3c813 --- /dev/null +++ b/csrc/punica/bgmv/bgmv_all.cu @@ -0,0 +1,21 @@ +#include "bgmv_config.h" +#include "bgmv_impl.cuh" + +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) +FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h new file mode 100644 index 0000000000000..ebf638f104c3f --- /dev/null +++ b/csrc/punica/bgmv/bgmv_config.h @@ -0,0 +1,59 @@ +#pragma once + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale); + +// clang-format off + +#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, narrow, 128) \ + f(in_T, out_T, W_T, narrow, 256) \ + f(in_T, out_T, W_T, narrow, 512) \ + f(in_T, out_T, W_T, narrow, 1024) \ + f(in_T, out_T, W_T, narrow, 1280) \ + f(in_T, out_T, W_T, narrow, 1728) \ + f(in_T, out_T, W_T, narrow, 1792) \ + f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2560) \ + f(in_T, out_T, W_T, narrow, 2752) \ + f(in_T, out_T, W_T, narrow, 3072) \ + f(in_T, out_T, W_T, narrow, 3456) \ + f(in_T, out_T, W_T, narrow, 3584) \ + f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 5120) \ + f(in_T, out_T, W_T, narrow, 5504) \ + f(in_T, out_T, W_T, narrow, 5632) \ + f(in_T, out_T, W_T, narrow, 6912) \ + f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 8192) \ + f(in_T, out_T, W_T, narrow, 9216) \ + f(in_T, out_T, W_T, narrow, 10240) \ + f(in_T, out_T, W_T, narrow, 11008) \ + f(in_T, out_T, W_T, narrow, 12288) \ + f(in_T, out_T, W_T, narrow, 13824) \ + f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 16384) \ + f(in_T, out_T, W_T, narrow, 20480) \ + f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 32000) \ + f(in_T, out_T, W_T, narrow, 32256) \ + f(in_T, out_T, W_T, narrow, 32512) \ + f(in_T, out_T, W_T, narrow, 32768) \ + f(in_T, out_T, W_T, narrow, 33024) \ + f(in_T, out_T, W_T, narrow, 36864) \ + f(in_T, out_T, W_T, narrow, 49152) \ +// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + +// Keep this in sync with vllm/config::LoRAConfig +#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ + FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + +// clang-format on diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh new file mode 100644 index 0000000000000..995de26e8bada --- /dev/null +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -0,0 +1,294 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include "vec_dtypes.cuh" + +namespace cg = cooperative_groups; + +// nthrs = (32, 4) +template +__global__ void +bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t j = blockIdx.x; + constexpr size_t num_pipeline_stages = 2; + constexpr size_t tile_size = tx * ty * vec_size; + __shared__ W_T W_shared[num_pipeline_stages * tile_size]; + __shared__ in_T X_shared[num_pipeline_stages * tile_size]; + __shared__ float y_warpwise[ty]; + + size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; + auto pipe = cuda::make_pipeline(); + + // pipeline load W/X and compute WX; + pipe.producer_acquire(); + cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + pipe.producer_commit(); + size_t copy_idx, compute_idx; + float y = 0.f; + vec_t x_vec; + vec_t w_vec; + size_t tile_idx; + +#pragma unroll + for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; + ++tile_idx) { + copy_idx = tile_idx % num_pipeline_stages; + // pipeline stage: async copy W fragment + pipe.producer_acquire(); + if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { + cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + W + (idx * feat_out + j) * feat_in + + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(W_copy_size), pipe); + cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size, + X + (batch_idx * feat_in) + tile_idx * tile_size + + (threadIdx.y * tx + threadIdx.x) * vec_size, + cuda::aligned_size_t(X_copy_size), pipe); + } + pipe.producer_commit(); + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // pipeline stage: compute WX + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = sum; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + } + + compute_idx = (tile_idx - 1) % num_pipeline_stages; + // final pipeline stage + pipe.consumer_wait(); + block.sync(); + x_vec.load(X_shared + X_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + w_vec.load(W_shared + W_shared_offset[compute_idx] + + (threadIdx.y * tx + threadIdx.x) * vec_size); + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += __shfl_down_sync(0xffffffff, sum, offset); + } + y_warpwise[threadIdx.y] = + ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) + ? sum + : 0.f; + block.sync(); +#pragma unroll + for (size_t i = 0; i < ty; ++i) { + y += y_warpwise[i]; + } + + block.sync(); + pipe.consumer_release(); + + // write Y; + if (block.thread_rank() == 0) { + Y[batch_idx * full_y_size + y_offset + j] += static_cast(y); + } +} + +// nthrs = (2, 16, 4) +template +__global__ void +bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t num_layers, int64_t layer_idx, + float scale) { + size_t batch_idx = blockIdx.y; + int64_t idx = indicies[batch_idx] * num_layers + layer_idx; + + if (idx < 0) { + return; + } + + auto block = cg::this_thread_block(); + size_t tile_idx = blockIdx.x; + + // load X; + vec_t x_vec; + x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); + + // load W; + vec_t w_vec; + w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + + block.thread_rank() * vec_size); + + float sum = 0.f; +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + sum += float(w_vec[i]) * float(x_vec[i]) * scale; + } + + cg::thread_block_tile g = cg::tiled_partition(block); +#pragma unroll + for (size_t offset = tx / 2; offset > 0; offset /= 2) { + sum += g.shfl_down(sum, offset); + } + sum = g.shfl(sum, 0); + + if (threadIdx.x == 0) { + Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + + threadIdx.z * ty + threadIdx.y] += static_cast(sum); + } +} + +template +void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, + const W_T *__restrict__ W, + const int64_t *__restrict__ indicies, int64_t y_offset, + int64_t full_y_size, int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + constexpr size_t vec_size = 8; + constexpr int tz = 4; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if constexpr (feat_in < feat_out) { + static_assert(feat_in % vec_size == 0); + constexpr int tx = feat_in / vec_size; + + static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || + (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || + (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); + + if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { + constexpr int ty = 32 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { + constexpr int ty = 16 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else { + constexpr int ty = 8 / tx; + dim3 nblks(feat_out / (ty * tz), batch_size); + dim3 nthrs(tx, ty, tz); + + bgmv_expand_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } else { + static_assert(feat_in % (vec_size * 32) == 0 || + feat_in % (vec_size * 16) == 0 || + feat_in % (vec_size * 8) == 0); + + if constexpr (feat_in % (vec_size * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { + constexpr int tx = 32; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { + constexpr int tx = 16; + constexpr int ty = 4; + + dim3 nblks(feat_out, batch_size); + dim3 nthrs(tx, ty); + + bgmv_shrink_kernel + <<>>(Y, X, W, indicies, y_offset, + full_y_size, num_layers, layer_idx, + scale); + } + } +} + +#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \ + template void bgmv_kernel( \ + out_T * __restrict__ Y, const in_T *__restrict__ X, \ + const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \ + int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ + int64_t num_layers, int64_t layer_idx, float scale); + +#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ + INST_BGMV(narrow, wide, in_T, out_T, W_T) \ + INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/vec_dtypes.cuh b/csrc/punica/bgmv/vec_dtypes.cuh new file mode 100644 index 0000000000000..cf00d869cf635 --- /dev/null +++ b/csrc/punica/bgmv/vec_dtypes.cuh @@ -0,0 +1,1324 @@ +#ifndef VEC_DTYPES_CUH_ +#define VEC_DTYPES_CUH_ + +#include +#include +#ifdef FLASHINFER_USE_FP8 +#include +#endif +#include + +#include + +#define FLASHINFER_INLINE \ + inline __attribute__((always_inline)) __device__ __host__ + +template +struct vec_t { + FLASHINFER_INLINE float_t &operator[](size_t i); + FLASHINFER_INLINE const float_t &operator[](size_t i) const; + FLASHINFER_INLINE void fill(float_t val); + FLASHINFER_INLINE void load(const float_t *ptr); + FLASHINFER_INLINE void store(float_t *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src); + template + FLASHINFER_INLINE void cast_load(const T *ptr); + template + FLASHINFER_INLINE void cast_store(T *ptr) const; + FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src); +}; + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + dst[i] = tgt_float_t(src[i]); + } +} + +template +FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr, + vec_t &dst) { + if constexpr (std::is_same::value) { + dst.load(src_ptr); + } else { + vec_t tmp; + tmp.load(src_ptr); + dst.cast_from(tmp); + } +} + +template +FLASHINFER_INLINE void cast_store_impl(const vec_t &src, + tgt_float_t *dst_ptr) { + if constexpr (std::is_same::value) { + src.store(dst_ptr); + } else { + vec_t tmp; + tmp.cast_from(src); + tmp.store(dst_ptr); + } +} + +#ifdef FLASHINFER_USE_FP8 +/******************* vec_t<__nv_fp8_e4m3> *******************/ + +// __nv_fp8_e4m3 x 1 +template <> +struct vec_t<__nv_fp8_e4m3, 1> { + __nv_fp8_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::fill(__nv_fp8_e4m3 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::load(const __nv_fp8_e4m3 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::store( + __nv_fp8_e4m3 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 1>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *dst = *src; +} + +// __nv_fp8_e4m3 x 2 +template <> +struct vec_t<__nv_fp8_e4m3, 2> { + __nv_fp8x2_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::fill(__nv_fp8_e4m3 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x2_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x2_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 2>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x2_e4m3 *)dst) = *((__nv_fp8x2_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 4 + +template <> +struct vec_t<__nv_fp8_e4m3, 4> { + __nv_fp8x4_e4m3 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::fill(__nv_fp8_e4m3 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::load(const __nv_fp8_e4m3 *ptr) { + data = *((__nv_fp8x4_e4m3 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::store( + __nv_fp8_e4m3 *ptr) const { + *((__nv_fp8x4_e4m3 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 4>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8x4_e4m3 *)dst) = *((__nv_fp8x4_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 8 + +template <> +struct vec_t<__nv_fp8_e4m3, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val); + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::fill(__nv_fp8_e4m3 val) { + ((__nv_fp8x4_e4m3 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::load(const __nv_fp8_e4m3 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::store( + __nv_fp8_e4m3 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e4m3, 8>::memcpy( + __nv_fp8_e4m3 *dst, const __nv_fp8_e4m3 *src) { + *((__nv_fp8_e4m3 *)dst) = *((__nv_fp8_e4m3 *)src); +} + +// __nv_fp8_e4m3 x 16 or more +template +struct vec_t<__nv_fp8_e4m3, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e4m3 &operator[](size_t i) { + return ((__nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e4m3 &operator[](size_t i) const { + return ((const __nv_fp8_e4m3 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e4m3 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e4m3 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e4m3 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e4m3 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e4m3 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e4m3 *dst, + const __nv_fp8_e4m3 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t<__nv_fp8_e5m2> *******************/ + +// __nv_fp8_e5m2 x 1 +template <> +struct vec_t<__nv_fp8_e5m2, 1> { + __nv_fp8_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::fill(__nv_fp8_e5m2 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::load(const __nv_fp8_e5m2 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::store( + __nv_fp8_e5m2 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 1>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *dst = *src; +} + +// __nv_fp8_e5m2 x 2 +template <> +struct vec_t<__nv_fp8_e5m2, 2> { + __nv_fp8x2_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::fill(__nv_fp8_e5m2 val) { + data.__x = + (__nv_fp8x2_storage_t(val.__x) << 8) | __nv_fp8x2_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x2_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x2_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 2>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x2_e5m2 *)dst) = *((__nv_fp8x2_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 4 + +template <> +struct vec_t<__nv_fp8_e5m2, 4> { + __nv_fp8x4_e5m2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::fill(__nv_fp8_e5m2 val) { + data.__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::load(const __nv_fp8_e5m2 *ptr) { + data = *((__nv_fp8x4_e5m2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::store( + __nv_fp8_e5m2 *ptr) const { + *((__nv_fp8x4_e5m2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 4>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8x4_e5m2 *)dst) = *((__nv_fp8x4_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 8 + +template <> +struct vec_t<__nv_fp8_e5m2, 8> { + uint2 data; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr); + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src); +}; + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::fill(__nv_fp8_e5m2 val) { + ((__nv_fp8x4_e5m2 *)(&data.x))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&data.y))->__x = (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | + __nv_fp8x4_storage_t(val.__x); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::load(const __nv_fp8_e5m2 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::store( + __nv_fp8_e5m2 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t<__nv_fp8_e5m2, 8>::memcpy( + __nv_fp8_e5m2 *dst, const __nv_fp8_e5m2 *src) { + *((__nv_fp8_e5m2 *)dst) = *((__nv_fp8_e5m2 *)src); +} + +// __nv_fp8_e5m2 x 16 or more + +template +struct vec_t<__nv_fp8_e5m2, vec_size> { + uint4 data[vec_size / 16]; + + FLASHINFER_INLINE __nv_fp8_e5m2 &operator[](size_t i) { + return ((__nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE const __nv_fp8_e5m2 &operator[](size_t i) const { + return ((const __nv_fp8_e5m2 *)data)[i]; + } + FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((__nv_fp8x4_e5m2 *)(&(data[i].x)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].y)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].z)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + ((__nv_fp8x4_e5m2 *)(&(data[i].w)))->__x = + (__nv_fp8x4_storage_t(val.__x) << 24) | + (__nv_fp8x4_storage_t(val.__x) << 16) | + (__nv_fp8x4_storage_t(val.__x) << 8) | __nv_fp8x4_storage_t(val.__x); + } + } + FLASHINFER_INLINE void load(const __nv_fp8_e5m2 *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(__nv_fp8_e5m2 *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2 *dst, + const __nv_fp8_e5m2 *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 16; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; +#endif + +/******************* vec_t *******************/ + +// half x 1 +template <> +struct vec_t { + half data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *dst = *src; +} + +// half x 2 +template <> +struct vec_t { + half2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + data = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((half2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((half2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((half2 *)dst) = *((half2 *)src); +} + +// half x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)(&data))[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)(&data))[i]; + } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half *ptr); + FLASHINFER_INLINE void store(half *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2 *)(&data.x) = make_half2(val, val); + *(half2 *)(&data.y) = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half *dst, const half *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// half x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + FLASHINFER_INLINE half &operator[](size_t i) { return ((half *)data)[i]; } + FLASHINFER_INLINE const half &operator[](size_t i) const { + return ((const half *)data)[i]; + } + FLASHINFER_INLINE void fill(half val) { +#pragma unroll + for (size_t i = 0; i < vec_size; ++i) { + *(half2 *)(&(data[i].x)) = make_half2(val, val); + *(half2 *)(&(data[i].y)) = make_half2(val, val); + *(half2 *)(&(data[i].z)) = make_half2(val, val); + *(half2 *)(&(data[i].w)) = make_half2(val, val); + } + } + FLASHINFER_INLINE void load(const half *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(half *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(half *dst, const half *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// nv_bfloat16 x 1 +template <> +struct vec_t { + nv_bfloat16 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = val; +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *ptr; +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *ptr = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *dst = *src; +} + +// nv_bfloat16 x 2 +template <> +struct vec_t { + nv_bfloat162 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + data = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((nv_bfloat162 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((nv_bfloat162 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((nv_bfloat162 *)dst) = *((nv_bfloat162 *)src); +} + +// nv_bfloat16 x 4 + +template <> +struct vec_t { + uint2 data; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)(&data))[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val); + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr); + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src); +}; + +FLASHINFER_INLINE void vec_t::fill(nv_bfloat16 val) { + *(nv_bfloat162 *)(&data.x) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&data.y) = make_bfloat162(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const nv_bfloat16 *ptr) { + data = *((uint2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(nv_bfloat16 *ptr) const { + *((uint2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { + *((uint2 *)dst) = *((uint2 *)src); +} + +// nv_bfloat16 x 8 or more + +template +struct vec_t { + uint4 data[vec_size / 8]; + + FLASHINFER_INLINE nv_bfloat16 &operator[](size_t i) { + return ((nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE const nv_bfloat16 &operator[](size_t i) const { + return ((const nv_bfloat16 *)data)[i]; + } + FLASHINFER_INLINE void fill(nv_bfloat16 val) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + *(nv_bfloat162 *)(&(data[i].x)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].y)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].z)) = make_bfloat162(val, val); + *(nv_bfloat162 *)(&(data[i].w)) = make_bfloat162(val, val); + } + } + FLASHINFER_INLINE void load(const nv_bfloat16 *ptr) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + data[i] = ((uint4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(nv_bfloat16 *ptr) const { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(nv_bfloat16 *dst, + const nv_bfloat16 *src) { +#pragma unoll + for (size_t i = 0; i < vec_size / 8; ++i) { + ((uint4 *)dst)[i] = ((uint4 *)src)[i]; + } + } +}; + +/******************* vec_t *******************/ + +// float x 1 + +template <> +struct vec_t { + float data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { data = val; } + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { data = *ptr; } + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { *ptr = data; } + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *dst = *src; +} + +// float x 2 + +template <> +struct vec_t { + float2 data; + + FLASHINFER_INLINE float &operator[](size_t i) { + return ((float *)(&data))[i]; + } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(&data))[i]; + } + FLASHINFER_INLINE void fill(float val); + FLASHINFER_INLINE void load(const float *ptr); + FLASHINFER_INLINE void store(float *ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src); +}; + +FLASHINFER_INLINE void vec_t::fill(float val) { + data = make_float2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const float *ptr) { + data = *((float2 *)ptr); +} + +FLASHINFER_INLINE void vec_t::store(float *ptr) const { + *((float2 *)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(float *dst, const float *src) { + *((float2 *)dst) = *((float2 *)src); +} + +// float x 4 or more +template +struct vec_t { + float4 data[vec_size / 4]; + + FLASHINFER_INLINE float &operator[](size_t i) { return ((float *)(data))[i]; } + FLASHINFER_INLINE const float &operator[](size_t i) const { + return ((const float *)(data))[i]; + } + FLASHINFER_INLINE void fill(float val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + FLASHINFER_INLINE void load(const float *ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4 *)ptr)[i]; + } + } + FLASHINFER_INLINE void store(float *ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t &src) { + cast_from_impl(src, *this); + } + template + FLASHINFER_INLINE void cast_load(const T *ptr) { + cast_load_impl(ptr, *this); + } + template + FLASHINFER_INLINE void cast_store(T *ptr) const { + cast_store_impl(*this, ptr); + } + FLASHINFER_INLINE static void memcpy(float *dst, const float *src) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)dst)[i] = ((float4 *)src)[i]; + } + } +}; + +/******************* vec_t type cast *******************/ + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = __half22float2(((half2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = half(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = __float22half2_rn(((float2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((float2 *)(&dst.data))[i] = + __bfloat1622float2(((nv_bfloat162 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = nv_bfloat16(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((nv_bfloat162 *)(&dst.data))[i] = + __float22bfloat162_rn(((float2 *)(&src.data))[i]); + } + } +} + +#ifdef FLASHINFER_USE_FP8 + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e4m3 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e4m3, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e4m3 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = + __nv_fp8x4_e4m3(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e4m3, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e4m3 *)(&dst.data) = __nv_fp8x2_e4m3(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e4m3 *)(&dst.data))[i] = __nv_fp8x4_e4m3( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else if constexpr (vec_size == 2) { + *(float2 *)(&dst.data) = float2(*(__nv_fp8x2_e5m2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4 *)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size> &src, + vec_t &dst) { + if constexpr (vec_size == 1) { + dst.data = float(src.data); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 2; ++i) { + ((half2 *)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e5m2(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(float2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = + __nv_fp8x4_e5m2(((float4 *)(&src.data))[i]); + } + } +} + +template +FLASHINFER_INLINE void cast_from_impl(const vec_t &src, + vec_t<__nv_fp8_e5m2, vec_size> &dst) { + if constexpr (vec_size == 1) { + dst.data = __nv_fp8_e4m3(src.data); + } else if constexpr (vec_size == 2) { + *(__nv_fp8x2_e5m2 *)(&dst.data) = __nv_fp8x2_e5m2(*(half2 *)(&src.data)); + } else { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + // NOTE(Zihao): need to double check if we properly handle flo and fhi + ((__nv_fp8x4_e5m2 *)(&dst.data))[i] = __nv_fp8x4_e5m2( + ((half2 *)(&src.data))[i * 2], ((half2 *)(&src.data))[i * 2 + 1]); + } + } +} + +#endif // FLASHINFER_USE_FP8 + +#endif // VEC_DTYPES_CUH_ diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc new file mode 100644 index 0000000000000..4ad46e5e1f726 --- /dev/null +++ b/csrc/punica/punica_ops.cc @@ -0,0 +1,563 @@ +#include +#include +#include + +#include + +#include "bgmv/bgmv_config.h" + +namespace { + +//====== utils ====== + +inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, + const char *a_name, const char *b_name) { + TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", + a.dim(), " vs ", b.dim()); + for (int i = 0; i < a.dim(); ++i) { + TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, + ".size(", i, ")"); + } +} + +inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { + return (uint32_t(a) << 16) | uint32_t(b); +} + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) \ + TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) + +#define CHECK_EQ(a, b) \ + TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +//====== bgmv ====== + +template +inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, + const int64_t *lora_indices, + uint16_t in_features, uint16_t out_features, + int64_t y_offset, int64_t full_y_size, + int64_t batch_size, int64_t num_layers, + int64_t layer_idx, float scale) { + switch (pack_u16(in_features, out_features)) { +#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \ + case pack_u16(feat_in, feat_out): \ + bgmv_kernel(Y, X, W, lora_indices, y_offset, \ + full_y_size, batch_size, num_layers, \ + layer_idx, scale); \ + break; +#define CASE(_in_T, _out_T, _W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \ + CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) + + FOR_BGMV_WIDE_NARROW(CASE, _, _, _) +#undef CASE +#undef CASE_ONESIDE + default: + return false; + } + + return true; +} + +void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, float scale) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t h_in = x.size(1); + int64_t h_out = y.size(1); + int64_t num_layers = w.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, 0, + h_out, B, num_layers, layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, + torch::Tensor indicies, int64_t layer_idx, + float scale, int64_t h_in, int64_t h_out, + int64_t y_offset) { + CHECK_INPUT(y); + CHECK_INPUT(x); + CHECK_INPUT(w); + CHECK_INPUT(indicies); + + CHECK_DIM(2, y); + CHECK_DIM(2, x); + CHECK_DIM(4, w); + CHECK_DIM(1, indicies); + + int64_t B = x.size(0); + int64_t num_layers = w.size(1); + int64_t full_y_size = y.size(1); + CHECK_EQ(w.size(3), h_in); + CHECK_EQ(w.size(2), h_out); + CHECK_EQ(indicies.size(0), x.size(0)); + CHECK_EQ(y.size(0), x.size(0)); + bool ok = false; + if (h_in < 65536 && h_out < 65536) { + // TODO: See if we can get rid of this massive nested switch + switch (x.scalar_type()) { + case at::ScalarType::Half: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (y.scalar_type()) { + case at::ScalarType::Half: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::BFloat16: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + case at::ScalarType::Float: + switch (w.scalar_type()) { + case at::ScalarType::Half: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + case at::ScalarType::BFloat16: + ok = launch_bgmv_kernel(static_cast(y.data_ptr()), + static_cast(x.data_ptr()), + static_cast(w.data_ptr()), + indicies.data_ptr(), h_in, h_out, + y_offset, full_y_size, B, num_layers, + layer_idx, scale); + break; + default: + break; + } + break; + default: + break; + } + break; + default: + break; + } + } + TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, + " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); +} + +} // namespace + +//====== pybind ====== + +#define DEFINE_pybind(name) m.def(#name, &name, #name); + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); + m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, + "dispatch_bgmv_low_level"); +} diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py new file mode 100644 index 0000000000000..8fdd243af69ff --- /dev/null +++ b/examples/multilora_inference.py @@ -0,0 +1,117 @@ +""" +This example shows how to use the multi-LoRA functionality for offline inference. + +Requires HuggingFace credentials for access to Llama2. +""" + +from typing import Optional, List, Tuple + +from huggingface_hub import snapshot_download + +from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput +from vllm.lora.request import LoRARequest + + +def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: + """Create a list of test prompts with their sampling parameters. + + 2 requests for base model, 4 requests for the LoRA. We define 2 + different LoRA adapters (using the same model for demo purposes). + Since we also set `max_loras=1`, the expectation is that the requests + with the second LoRA adapter will be ran after all requests with the + first adapter have finished. + """ + return [ + ("A robot may not injure a human being", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128), None), + ("To be or not to be,", + SamplingParams(temperature=0.8, + top_k=5, + presence_penalty=0.2, + max_tokens=128), None), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + SamplingParams(temperature=0.0, + logprobs=1, + prompt_logprobs=1, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora2", 2, lora_path)), + ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + SamplingParams(n=3, + best_of=3, + use_beam_search=True, + temperature=0, + max_tokens=128, + stop_token_ids=[32003]), + LoRARequest("sql-lora", 1, lora_path)), + ] + + +def process_requests(engine: LLMEngine, + test_prompts: List[Tuple[str, SamplingParams, + Optional[LoRARequest]]]): + """Continuously process a list of prompts and handle the outputs.""" + request_id = 0 + + while test_prompts or engine.has_unfinished_requests(): + if test_prompts: + prompt, sampling_params, lora_request = test_prompts.pop(0) + engine.add_request(str(request_id), + prompt, + sampling_params, + lora_request=lora_request) + request_id += 1 + + request_outputs: List[RequestOutput] = engine.step() + + for request_output in request_outputs: + if request_output.finished: + print(request_output) + + +def initialize_engine() -> LLMEngine: + """Initialize the LLMEngine.""" + # max_loras: controls the number of LoRAs that can be used in the same + # batch. Larger numbers will cause higher memory usage, as each LoRA + # slot requires its own preallocated tensor. + # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger + # numbers will cause higher memory usage. If you know that all LoRAs will + # use the same rank, it is recommended to set this as low as possible. + # max_cpu_loras: controls the size of the CPU LoRA cache. + engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", + enable_lora=True, + max_loras=1, + max_lora_rank=8, + max_cpu_loras=2, + max_num_seqs=256) + return LLMEngine.from_engine_args(engine_args) + + +def main(): + """Main function that sets up and runs the prompt processing.""" + engine = initialize_engine() + lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + test_prompts = create_test_prompts(lora_path) + process_requests(engine, test_prompts) + + +if __name__ == '__main__': + main() diff --git a/setup.py b/setup.py index fb37a8d952314..3baf27aa86532 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,16 @@ +import contextlib import io import os import re import subprocess -from typing import List, Set import warnings +from pathlib import Path +from typing import List, Set from packaging.version import parse, Version import setuptools import torch +import torch.utils.cpp_extension as torch_cpp_ext from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME ROOT_DIR = os.path.dirname(__file__) @@ -28,7 +31,7 @@ def _is_neuron() -> bool: torch_neuronx_installed = True try: subprocess.run(["neuron-ls"], capture_output=True, check=True) - except FileNotFoundError as e: + except FileNotFoundError: torch_neuronx_installed = False return torch_neuronx_installed @@ -96,10 +99,16 @@ def get_hipcc_rocm_version(): return None +def glob(pattern: str): + root = Path(__name__).parent + return [str(p) for p in root.glob(pattern)] + + def get_neuronxcc_version(): import sysconfig site_dir = sysconfig.get_paths()["purelib"] - version_file = os.path.join(site_dir, "neuronxcc", "version", "__init__.py") + version_file = os.path.join(site_dir, "neuronxcc", "version", + "__init__.py") # Check if the command was executed successfully with open(version_file, "rt") as fp: @@ -178,6 +187,8 @@ def get_torch_arch_list() -> Set[str]: "GPUs with compute capability below 7.0 are not supported.") compute_capabilities.add(f"{major}.{minor}") +ext_modules = [] + if _is_cuda(): nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) if not compute_capabilities: @@ -215,6 +226,8 @@ def get_torch_arch_list() -> Set[str]: raise RuntimeError( "CUDA 11.8 or higher is required for compute capability 9.0.") + NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() + # Add target compute capabilities to NVCC flags. for capability in compute_capabilities: num = capability[0] + capability[2] @@ -223,6 +236,14 @@ def get_torch_arch_list() -> Set[str]: NVCC_FLAGS += [ "-gencode", f"arch=compute_{num},code=compute_{num}" ] + if int(capability[0]) >= 8: + NVCC_FLAGS_PUNICA += [ + "-gencode", f"arch=compute_{num},code=sm_{num}" + ] + if capability.endswith("+PTX"): + NVCC_FLAGS_PUNICA += [ + "-gencode", f"arch=compute_{num},code=compute_{num}" + ] # Use NVCC threads to parallelize the build. if nvcc_cuda_version >= Version("11.2"): @@ -230,6 +251,36 @@ def get_torch_arch_list() -> Set[str]: num_threads = min(os.cpu_count(), nvcc_threads) NVCC_FLAGS += ["--threads", str(num_threads)] + # changes for punica kernels + NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS + REMOVE_NVCC_FLAGS = [ + '-D__CUDA_NO_HALF_OPERATORS__', + '-D__CUDA_NO_HALF_CONVERSIONS__', + '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', + '-D__CUDA_NO_HALF2_OPERATORS__', + ] + for flag in REMOVE_NVCC_FLAGS: + with contextlib.suppress(ValueError): + torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) + + install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "1"))) + device_count = torch.cuda.device_count() + for i in range(device_count): + major, minor = torch.cuda.get_device_capability(i) + if major < 8: + install_punica = False + break + if install_punica: + ext_modules.append( + CUDAExtension( + name="vllm._punica_C", + sources=["csrc/punica/punica_ops.cc"] + + glob("csrc/punica/bgmv/*.cu"), + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS_PUNICA, + }, + )) elif _is_hip(): amd_arch = get_amdgpu_offload_arch() if amd_arch not in ROCM_SUPPORTED_ARCHS: @@ -240,8 +291,6 @@ def get_torch_arch_list() -> Set[str]: elif _is_neuron(): neuronxcc_version = get_neuronxcc_version() -ext_modules = [] - vllm_extension_sources = [ "csrc/cache_kernels.cu", "csrc/attention/attention_kernels.cu", diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 174975802dc0d..1edb19c550010 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -25,6 +25,13 @@ async def step_async(self): return [RequestOutput( request_id=self.request_id)] if self.request_id else [] + async def encode_request_async( + self, + *args, + **kwargs, + ): + return [1] + def generate(self, request_id): self.request_id = request_id @@ -35,6 +42,10 @@ def add_request(self, **kwargs): del kwargs # Unused self.add_request_calls += 1 + async def add_request_async(self, **kwargs): + del kwargs # Unused + self.add_request_calls += 1 + def abort_request(self, request_id): del request_id # Unused self.abort_request_calls += 1 diff --git a/tests/lora/__init__.py b/tests/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py new file mode 100644 index 0000000000000..c1b3d04c713b5 --- /dev/null +++ b/tests/lora/conftest.py @@ -0,0 +1,143 @@ +import contextlib +import gc +import tempfile +from collections import OrderedDict +from unittest.mock import patch, MagicMock + +import pytest +import ray +import torch +import torch.nn as nn +from huggingface_hub import snapshot_download + +import vllm +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.model_loader import get_model +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + destroy_model_parallel, initialize_model_parallel) + + +def cleanup(): + destroy_model_parallel() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + torch.cuda.empty_cache() + ray.shutdown() + + +@pytest.fixture(autouse=True) +def cleanup_fixture(): + yield + cleanup() + + +@pytest.fixture +def dist_init(): + if not torch.distributed.is_initialized(): + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + torch.distributed.all_reduce(torch.zeros(1).cuda()) + initialize_model_parallel(1, 1) + yield + cleanup() + + +@pytest.fixture +def dist_init_torch_only(): + if torch.distributed.is_initialized(): + return + temp_file = tempfile.mkstemp()[1] + torch.distributed.init_process_group( + backend="nccl", + world_size=1, + rank=0, + init_method=f"file://{temp_file}", + ) + + +@pytest.fixture +def dummy_model() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", ColumnParallelLinear(50, 10)), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture +def dummy_model_gate_up() -> nn.Module: + model = nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(764, 100)), + ("dense2", RowParallelLinear(100, 50)), + ( + "layer1", + nn.Sequential( + OrderedDict([ + ("dense1", ColumnParallelLinear(100, 10)), + ("dense2", RowParallelLinear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("gate_up_proj", MergedColumnParallelLinear(50, [5, 5])), + ("outact", nn.Sigmoid()), + # Special handling for lm_head & sampler + ("lm_head", ParallelLMHead(512, 10)), + ("sampler", Sampler(512)) + ])) + model.config = MagicMock() + return model + + +@pytest.fixture(scope="session") +def sql_lora_files(): + return snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") + + +@pytest.fixture +def llama_2_7b_engine_extra_embeddings() -> nn.Module: + cleanup() + get_model_old = get_model + + def get_model_patched(model_config, lora_config=None): + return get_model_old(model_config, + LoRAConfig(max_loras=4, max_lora_rank=8)) + + with patch("vllm.worker.model_runner.get_model", get_model_patched): + engine = vllm.LLM("meta-llama/Llama-2-7b-hf", enable_lora=False) + yield engine.llm_engine + del engine + cleanup() + + +@pytest.fixture +def llama_2_7b_model_extra_embeddings( + llama_2_7b_engine_extra_embeddings) -> nn.Module: + yield llama_2_7b_engine_extra_embeddings.driver_worker.model_runner.model diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py new file mode 100644 index 0000000000000..71c671132205a --- /dev/null +++ b/tests/lora/test_layers.py @@ -0,0 +1,709 @@ +import pytest +import random +from copy import deepcopy +from dataclasses import dataclass +from typing import List, Optional, Dict, Tuple + +import torch +import torch.nn.functional as F + +from vllm.lora.layers import ( + ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + QKVParallelLinearWithLora, + VocabParallelEmbeddingWithLoRA, + RowParallelLinearWithLoRA, + SamplerWithLoRA, + LoRAMapping, + BaseLayerWithLoRA, +) +from vllm.lora.models import LoRALayerWeights, convert_mapping, PackedLoRALayerWeights +from vllm.config import LoRAConfig +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.utils import set_random_seed + +from .utils import DummyLoRAManager + +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.float32: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +def get_random_id_to_index(num_loras: int, + num_slots: int, + log: bool = True) -> List[Optional[int]]: + """Creates a random lora_id_to_index mapping. + + Args: + num_loras: The number of active loras in the mapping. + num_slots: The number of slots in the mapping. Must be larger + than num_loras. + log: Whether to log the output. + """ + + if num_loras > num_slots: + raise ValueError( + f"num_loras is higher than num_slots: {num_loras} > {num_slots}. " + "num_loras must be less than or equal to num_slots.") + + slots: List[Optional[int]] = [None] * num_slots + random_slot_selections = (torch.randperm(num_slots)[:num_loras]).tolist() + for lora_id, slot_idx in enumerate(random_slot_selections, start=1): + slots[slot_idx] = lora_id + + if log: + print(f"Created lora_id_to_index mapping: {slots}.") + + return slots + + +def populate_loras( + id_to_index: List[Optional[int]], + layer: BaseLayerWithLoRA, + layer_weights: torch.Tensor, + generate_embeddings_tensor: int = 0, + repeats: int = 1, +) -> Tuple[Dict[int, LoRALayerWeights], Dict[int, List[LoRALayerWeights]]]: + """This method populates the lora layers with lora weights. + + Args: + id_to_index: a list of lora ids. The index of the lora id + represents which memory slot the lora matrices are + stored in. A None value indicates a free slot. + layer: the LoRAlayer to populate. + layer_weights: the PyTorch tensor containing the layer's + weights. + generate_embeddings_tensor: whether to generate an + embeddings tensor for each LoRA. + repeats: must only be set for column parallel packed + layers. Indicates the number of loras to compose + together to create a single lora layer. + """ + + # Dictionary that maps the lora ID to the + # corresponding lora weights. + lora_dict: Dict[int, LoRALayerWeights] = dict() + + # Dictionary that maps the lora ID to the + # corresponding subloras. Only useful when + # repeats > 1. + sublora_dict: Dict[int, List[LoRALayerWeights]] = dict() + + for slot_idx, lora_id in enumerate(id_to_index): + if lora_id is not None: + subloras = [] + sublora_len = layer_weights.shape[0] // repeats + for i in range(repeats): + sublora = DummyLoRAManager().init_random_lora( + module_name=f"fake_{i}", + weight=layer_weights, + generate_embeddings_tensor=generate_embeddings_tensor, + ) + sublora.lora_b = sublora.lora_b[:, (sublora_len * + i):(sublora_len * (i + 1))] + sublora.optimize() + subloras.append(sublora) + + lora = PackedLoRALayerWeights.pack( + subloras) if repeats > 1 else subloras[0] + + layer.set_lora( + slot_idx, + lora_a=lora.lora_a, + lora_b=lora.lora_b, + embeddings_tensor=lora.embeddings_tensor, + ) + + lora_dict[lora_id] = lora + sublora_dict[lora_id] = subloras + + return lora_dict, sublora_dict + + +def create_random_inputs( + active_lora_ids: List[int], + num_inputs: int, + input_size: Tuple[int, ...], + input_range: Tuple[float, float], + input_type: torch.dtype = torch.int, +) -> Tuple[List[torch.Tensor], List[int], List[int]]: + """Creates random inputs. + + Args: + active_lora_ids: lora IDs of active lora weights. + num_inputs: the number of inputs to create. + input_size: the size of each individual input. + input_range: the range of values to include in the input. + input_range[0] <= possible input values < input_range[1] + input_type: the type of values in the input. + """ + + low, high = input_range + + inputs, index_mapping, prompt_mapping = [], [], [] + for _ in range(num_inputs): + if input_type == torch.int: + inputs.append( + torch.randint(low=int(low), + high=int(high), + size=input_size, + device="cuda")) + else: + inputs.append( + torch.rand(size=input_size, dtype=input_type, device="cuda") * + high + low) + + lora_id = random.choice(active_lora_ids) + index_mapping += [lora_id] * input_size[0] + prompt_mapping += [lora_id] + + return inputs, index_mapping, prompt_mapping + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings(dist_init, num_loras) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding.weight.data = torch.rand_like(embedding.weight.data) + embedding.weight.data[512:, :] = 0 + lora_embedding = VocabParallelEmbeddingWithLoRA(embedding) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + embedding, lora_embedding = create_random_embedding_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=embedding.weight.T, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info) + + lora_result = lora_embedding(torch.cat(inputs)) + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = embedding(input_) + after_a = F.embedding( + input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(inputs)) + expected_result = embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +# @pytest.mark.skip(reason="Fails when loras are in any slot other than the first.") +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_embeddings_with_new_embeddings(dist_init, num_loras) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_embedding_layer(): + embedding = VocabParallelEmbedding(512, 256) + embedding_data = torch.rand_like(embedding.weight.data) + embedding.weight.data = embedding_data + embedding.weight.data[512:, :] = 0 + expanded_embedding = VocabParallelEmbedding( + 512 + lora_config.lora_extra_vocab_size * max_loras, + 256, + org_num_embeddings=512) + expanded_embedding.weight.data[:512, :] = embedding_data + # We need to deepcopy the embedding as it will be modifed + # in place + lora_embedding = VocabParallelEmbeddingWithLoRA( + deepcopy(expanded_embedding)) + lora_embedding.create_lora_weights(max_loras, lora_config) + + return expanded_embedding, lora_embedding + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + expanded_embedding, lora_embedding = create_random_embedding_layer() + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_embedding, + layer_weights=torch.zeros( + (256, 512 + lora_config.lora_extra_vocab_size)), + generate_embeddings_tensor=256, + ) + + # All embeddings tensors have the same shape. + embeddings_tensors = [ + lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys()) + ] + embeddings_tensor_len = embeddings_tensors[0].shape[0] + + # Add empty embeddings_tensors for unoccupied lora slots. + for _ in range(max_loras - len(embeddings_tensors)): + embeddings_tensors.append( + torch.zeros(embeddings_tensors[0].shape, device="cuda")) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + # Force some of the inputs to be in the extended embeddings range + # to guarantee that their behavior is tested. + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + embedding_id = lora_id - 1 + input_[-1] = 512 + (embedding_id * embeddings_tensor_len) + original_input_[-1] = 512 + input_[-2] = 512 + ((embedding_id + 1) * embeddings_tensor_len - 1) + original_input_[-2] = 512 + embeddings_tensor_len - 1 + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + expanded_embedding.weight[512:512 + + (embeddings_tensor_len * + max_loras)] = torch.cat(embeddings_tensors) + + lora_result = lora_embedding(torch.cat(original_inputs)) + + expected_results = [] + for input_, original_input_, lora_id in zip(inputs, original_inputs, + prompt_mapping): + lora = lora_dict[lora_id] + result = expanded_embedding(input_) + after_a = F.embedding( + original_input_, + lora.lora_a, + ) + result += (after_a @ lora.lora_b) + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_embedding.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=num_loras * 3, + input_size=(200, ), + input_range=(1, 512), + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + original_inputs = deepcopy(inputs) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_embedding.set_mapping(*mapping_info, ) + + lora_result = lora_embedding(torch.cat(original_inputs)) + expected_result = expanded_embedding(torch.cat(inputs)) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +def test_lm_head_sampler(dist_init, num_loras) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_sampler_layer(): + linear = ParallelLMHead(32000 + lora_config.lora_extra_vocab_size, + 1024, 32000) + linear.weight.data = torch.rand_like(linear.weight.data) + linear.weight.data[:, 32000:] = 0 + sampler = Sampler(32000 + lora_config.lora_extra_vocab_size, 32000) + lora_sampler = SamplerWithLoRA(sampler, 1024, linear.weight.dtype, + linear.weight.device) + lora_sampler.create_lora_weights(max_loras, lora_config) + + return linear, sampler, lora_sampler + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, sampler, lora_sampler = create_random_sampler_layer() + + # NOTE: all the generated loras share the same embeddings tensor. + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_sampler, + layer_weights=linear.weight, + generate_embeddings_tensor=1024, + ) + embeddings_tensor = list(lora_dict.values())[0].embeddings_tensor + embeddings_tensor_len = embeddings_tensor.shape[0] + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=8 * num_loras, # * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + input_ = torch.rand(20, 1024, device="cuda") + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 32000, + lora_config.lora_extra_vocab_size, + ) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=linear.weight, + embedding_bias=None) + + original_weight = linear.weight.clone() + + linear.weight[sampler.org_vocab_size:sampler.org_vocab_size + + embeddings_tensor_len] = embeddings_tensor + + sampler.org_vocab_size = 32000 + lora_config.lora_extra_vocab_size + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = sampler._get_logits(hidden_states=input_, + embedding=linear.weight, + embedding_bias=None) + result[:, 32000 + embeddings_tensor_len:] = float("-inf") + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + sampler.org_vocab_size = 32000 + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_sampler.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=8 * num_loras * 3, + input_size=(1, 1024), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 32000, + lora_config.lora_extra_vocab_size) + lora_sampler.set_mapping(*mapping_info, ) + + lora_result = lora_sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None)[:, :32000] + expected_result = sampler._get_logits(hidden_states=torch.cat(inputs), + embedding=original_weight, + embedding_bias=None) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("orientation", ["row", "column"]) +def test_linear_parallel(dist_init, num_loras, orientation) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_random_linear_parallel_layer(): + if orientation == "row": + linear = RowParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = RowParallelLinearWithLoRA(linear) + else: + linear = ColumnParallelLinear(4096, 4096, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = ColumnParallelLinearWithLoRA(linear) + lora_linear.create_lora_weights(max_loras, lora_config) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + linear, lora_linear = create_random_linear_parallel_layer() + + lora_dict, _ = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + lora = lora_dict[lora_id] + result = linear(input_)[0] + result += input_ @ lora.lora_a @ lora.lora_b * lora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + # Check that resetting the lora weights succeeds + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras, + 512, lora_config.lora_extra_vocab_size) + lora_linear.set_mapping(*mapping_info, ) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + +@torch.inference_mode() +@pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) +@pytest.mark.parametrize("repeats", [2, 3]) +def test_column_parallel_packed(dist_init, num_loras, repeats) -> None: + + max_loras = 8 + lora_config = LoRAConfig(max_loras=max_loras, + max_lora_rank=8, + lora_dtype=torch.float16) + + def create_column_parallel_packed_layer(): + if repeats == 2: + linear = MergedColumnParallelLinear(4096, [4096] * repeats, + bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = MergedColumnParallelLinearWithLoRA(linear) + else: + linear = QKVParallelLinear(4096, 64, 32, bias=False) + linear.weight.data = torch.rand_like(linear.weight.data) + lora_linear = QKVParallelLinearWithLora(linear) + + @dataclass + class FakeConfig: + hidden_size = 4096 + num_key_value_heads = 32 + num_attention_heads = 32 + + lora_linear.create_lora_weights(max_loras, + lora_config, + model_config=FakeConfig()) + + return linear, lora_linear + + for i in range(10): + set_random_seed(i) + + id_to_index = get_random_id_to_index(num_loras, max_loras) + + linear, lora_linear = create_column_parallel_packed_layer() + + lora_dict, sublora_dict = populate_loras( + id_to_index, + layer=lora_linear, + layer_weights=linear.weight, + repeats=repeats, + ) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=list(lora_dict.keys()), + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + + expected_results = [] + for input_, lora_id in zip(inputs, prompt_mapping): + result = linear(input_)[0] + subloras = sublora_dict[lora_id] + for i, sublora in enumerate(subloras): + result[:, sublora.lora_b.shape[1] * i:sublora.lora_b.shape[1] * ( + i + 1 + )] += input_ @ sublora.lora_a @ sublora.lora_b * sublora.scaling + expected_results.append(result) + expected_result = torch.cat(expected_results) + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) + + for slot_idx in range(max_loras): + lora_linear.reset_lora(slot_idx) + + inputs, index_mapping, prompt_mapping = create_random_inputs( + active_lora_ids=[0], + num_inputs=32 * num_loras, + input_size=(1, 4096), + input_range=(0, 1), + input_type=torch.float32, + ) + lora_mapping = LoRAMapping(index_mapping, prompt_mapping) + + mapping_info = convert_mapping( + lora_mapping, + id_to_index, + max_loras, + 512, + lora_config.lora_extra_vocab_size, + ) + lora_linear.set_mapping(*mapping_info) + + lora_result = lora_linear(torch.cat(inputs))[0] + expected_result = linear(torch.cat(inputs))[0] + + rtol, atol = TOLERANCES[lora_result.dtype] + assert torch.allclose(lora_result, + expected_result, + rtol=rtol, + atol=atol) diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py new file mode 100644 index 0000000000000..06fbf19eea824 --- /dev/null +++ b/tests/lora/test_llama.py @@ -0,0 +1,144 @@ +import pytest +import ray + +import vllm +from vllm.lora.request import LoRARequest +from .conftest import cleanup + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + + +def do_sample(llm, lora_path: str, lora_id: int): + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]" + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@pytest.mark.parametrize("tp_size", [1]) +def test_llama_lora(sql_lora_files, tp_size): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < tp_size: + # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=tp_size) + + expected_no_lora_output = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the womens doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", + ] + expected_lora_output = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " + ] + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output + + print("removing lora") + + +@pytest.mark.skip("Requires multiple GPUs") +def test_llama_tensor_parallel_equality(sql_lora_files): + # Cannot use as it will initialize torch.cuda too early... + # if torch.cuda.device_count() < 4: + # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") + + llm_tp1 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1) + output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) + + del llm_tp1 + cleanup() + + llm_tp2 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=2) + output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) + + del llm_tp2 + cleanup() + + assert output_tp1 == output_tp2 + + llm_tp4 = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4) + output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) + + del llm_tp4 + cleanup() + + assert output_tp1 == output_tp4 + + +def test_llama_lora_warmup(sql_lora_files): + """Test that the LLM initialization works with a warmup LORA path and is more conservative""" + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_lora(): + llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) + num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_lora_warmup + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_no_lora(): + llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) + num_gpu_blocks_no_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_no_lora_warmup + + num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) + num_gpu_blocks_no_lora_warmup = ray.get( + get_num_gpu_blocks_no_lora.remote()) + assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( + "The warmup with lora should be more" + " conservative than without lora, therefore the number of memory blocks for the KV cache should be " + "less when using lora than when not using lora") diff --git a/tests/lora/test_lora.py b/tests/lora/test_lora.py new file mode 100644 index 0000000000000..3415d36b7e341 --- /dev/null +++ b/tests/lora/test_lora.py @@ -0,0 +1,224 @@ +import pytest +import torch + +from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice + +from .utils import DummyLoRAManager + +TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4] +QKV_TENSOR_SIZES = [ + (8192, 1024, 1024), + (8192 // 8, 1024 // 8, 1024 // 8), + (4096, 4096, 4096), + (4096 // 2, 4096 // 2, 4096 // 2), +] +BATCH_SIZES = [8, 32, 256] +RANKS = [8] +DTYPES = [torch.float16] +TOLERANCES = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), +} + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora(m, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name, weight, rank=rank) + lora = manager.get_module_lora(module_name) + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = input @ lora.lora_a @ lora.lora_b * lora.scaling + + lora_a_stack = torch.zeros(8, + 1, + lora.lora_a.shape[1], + lora.lora_a.shape[0], + device="cuda", + dtype=dtype) + lora_b_stack = torch.zeros(8, + 1, + lora.lora_b.shape[1], + lora.lora_b.shape[0], + device="cuda", + dtype=dtype) + for i in range(lora_a_stack.shape[0]): + lora_a_stack[i][0] = lora.lora_a.T + lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora( + input, lora_a_stack, lora_b_stack, + torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"), + output) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora(input, lora_a_stack, lora_b_stack, + torch.full((len(input), ), -1, device="cuda"), output) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("m", TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None: + if m % 2 != 0: + pytest.skip("m must be divisible by 2") + if m // 2 not in TENSOR_SIZES: + pytest.skip("m//2 must be in TENSOR_SIZES") + + manager = DummyLoRAManager() + + module_name = "module" + weight = torch.rand([m // 2, n], device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "1", weight, rank=rank) + lora_1 = manager.get_module_lora(module_name + "1") + manager.init_random_lora(module_name + "2", weight, rank=rank) + lora_2 = manager.get_module_lora(module_name + "2") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling, + input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_a.shape[1], + lora_1.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_1.lora_b.shape[1], + lora_1.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_1.lora_a.T + lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T + lora_a_stacks[1][i][0] = lora_2.lora_a.T + lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T + + output = torch.zeros(k, m, device="cuda", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, (m // 2, m // 2)) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, (m // 2, m // 2)) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() + + +@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES) +@pytest.mark.parametrize("n", TENSOR_SIZES) +@pytest.mark.parametrize("k", BATCH_SIZES) +@pytest.mark.parametrize("rank", RANKS) +@pytest.mark.parametrize("dtype", DTYPES) +def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None: + manager = DummyLoRAManager() + + module_name = "module" + weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype) + weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype) + + manager.init_random_lora(module_name + "q", weight_q, rank=rank) + lora_q = manager.get_module_lora(module_name + "q") + manager.init_random_lora(module_name + "k", weight_kv, rank=rank) + lora_k = manager.get_module_lora(module_name + "k") + manager.init_random_lora(module_name + "v", weight_kv, rank=rank) + lora_v = manager.get_module_lora(module_name + "v") + + input = torch.rand(k, n, device="cuda", dtype=dtype) + expected = torch.cat([ + input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling, + input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling, + input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling + ], + dim=1) + + lora_a_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_a.shape[1], + lora_q.lora_a.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_a.shape[1], + lora_k.lora_a.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + lora_b_stacks = [ + torch.zeros(8, + 1, + lora_q.lora_b.shape[1], + lora_q.lora_b.shape[0], + device="cuda", + dtype=dtype) + ] + [ + torch.zeros(8, + 1, + lora_k.lora_b.shape[1], + lora_k.lora_b.shape[0], + device="cuda", + dtype=dtype) for i in range(2) + ] + for i in range(lora_a_stacks[0].shape[0]): + lora_a_stacks[0][i][0] = lora_q.lora_a.T + lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T + lora_a_stacks[1][i][0] = lora_k.lora_a.T + lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T + lora_a_stacks[2][i][0] = lora_v.lora_a.T + lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T + + output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype) + _apply_lora_packed_nslice( + input, lora_a_stacks, lora_b_stacks, + torch.randint(0, + lora_a_stacks[0].shape[0], (len(input), ), + device="cuda"), output, (qkv[0], qkv[1], qkv[2])) + + rtol, atol = TOLERANCES[dtype] + assert torch.allclose(expected, output, rtol=rtol, atol=atol) + + output[:] = 0 + _apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks, + torch.full((len(input), ), -1, device="cuda"), + output, (qkv[0], qkv[1], qkv[2])) + assert torch.allclose(torch.zeros_like(output), output) + + manager.reset_lora() diff --git a/tests/lora/test_lora_manager.py b/tests/lora/test_lora_manager.py new file mode 100644 index 0000000000000..78a4a5bc5ecd2 --- /dev/null +++ b/tests/lora/test_lora_manager.py @@ -0,0 +1,475 @@ +import os +from typing import List + +import pytest +import torch +from safetensors.torch import load_file +from torch import nn + +from vllm.config import LoRAConfig +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + RowParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA) +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.models import (EMBEDDING_MODULES, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, LoRAMapping) +from vllm.lora.request import LoRARequest +from vllm.lora.worker_manager import (LRUCacheWorkerLoRAManager, + WorkerLoRAManager) +from vllm.model_executor.layers.linear import RowParallelLinear + + +def test_from_lora_tensors(sql_lora_files): + tensors = load_file( + os.path.join(sql_lora_files, "adapter_model.safetensors")) + new_embeddings = load_file( + os.path.join(sql_lora_files, "new_embeddings.safetensors")) + lora_model = LoRAModel.from_lora_tensors(1, + 8, + 16, + tensors, + "cuda", + embeddings=new_embeddings) + for module_name, lora in lora_model.loras.items(): + assert lora.module_name == module_name + assert lora.rank == 8 + assert lora.lora_alpha == 16 + assert lora.lora_a is not None + assert lora.lora_b is not None + assert (lora.lora_a.shape[1] == lora.lora_b.shape[0] + ), f"{lora.lora_a.shape=}, {lora.lora_b.shape=}" + assert lora.lora_a.shape[1] == 8 + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), None) + if embeddings_module: + assert torch.equal( + lora.embeddings_tensor, + new_embeddings[EMBEDDING_MODULES[embeddings_module]].to( + device=lora.embeddings_tensor.device)) + else: + assert lora.embeddings_tensor is None + + +def create_lora(lora_id: int, model: nn.Module, + sub_modules: List[str]) -> LoRAModel: + loras = {} + for name in sub_modules: + w = model.get_submodule(name).weight + loras[name] = LoRALayerWeights( + name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0]], device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def create_packed_lora( + lora_id: int, + model: nn.Module, + module_name, + replaced_module_names, + empty_replaced_module_name=None, +) -> LoRAModel: + w = model.get_submodule(module_name).weight + loras = {} + for replaced_module_name in replaced_module_names: + if replaced_module_name == empty_replaced_module_name: + continue + loras[replaced_module_name] = LoRALayerWeights( + replaced_module_name, + 8, + 16, + torch.rand([w.shape[1], 8], device="cuda"), + torch.rand([8, w.shape[0] // len(replaced_module_names)], + device="cuda"), + ) + return LoRAModel(lora_id, 8, loras) + + +def test_replace_submodules(dist_init, dummy_model): + model = dummy_model + manager = LoRAModelManager(model, + 1, + 1, + 1, + LoRAConfig(max_lora_rank=8, + max_cpu_loras=8, + max_loras=8), + lora_target_modules=["dense1", "layer1.dense2"]) + model = manager.model + + assert isinstance(model.get_submodule("dense1"), + ColumnParallelLinearWithLoRA) + assert isinstance(model.get_submodule("layer1.dense1"), + ColumnParallelLinearWithLoRA) + assert isinstance(model.get_submodule("dense2"), RowParallelLinear) + assert isinstance(model.get_submodule("layer1.dense2"), + RowParallelLinearWithLoRA) + + +def test_lora_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_index_to_id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_index_to_id[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert manager.add_lora(model_lora3) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + with pytest.raises(ValueError): + assert manager.activate_lora(3) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_index_to_id[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] is None + assert manager.add_lora(model_lora2) + assert manager.activate_lora(3) + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] is None + assert manager.activate_lora(2) + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 + + +def test_lora_lru_cache_model_manager(dist_init, dummy_model): + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, + 2, + 2, + 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=3, max_loras=2), + lora_target_modules=["dense1", "dense2", "lm_head"]) + assert all(x is None for x in manager.lora_index_to_id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_index_to_id[0] == 1 + assert not manager.add_lora(model_lora1) + assert not manager.activate_lora(1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + assert not manager.add_lora(model_lora2) + assert not manager.activate_lora(2) + assert manager.add_lora(model_lora3) + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + assert manager.activate_lora(3) + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 + assert manager.remove_lora(model_lora2.id) + assert manager.lora_index_to_id[1] is None + assert not manager.remove_lora(model_lora2.id) + assert manager.remove_lora(model_lora1.id) + assert not manager.remove_lora(model_lora1.id) + assert manager.add_lora(model_lora1) + assert manager.activate_lora(1) + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 1 + assert manager.add_lora(model_lora2) + assert manager.deactivate_lora(3) + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 1 + assert manager.activate_lora(2) + assert manager.lora_index_to_id[0] == 2 + assert manager.lora_index_to_id[1] == 1 + assert manager.activate_lora(3) + assert manager.lora_index_to_id[0] == 2 + assert manager.lora_index_to_id[1] == 3 + + +def test_lru_lora_model_manager(dist_init, dummy_model): + # This tests just the LRU cache functionality, everything else is + # tested in test_lora_model_manager + model = dummy_model + model_lora1 = create_lora(1, model, ["layer1.dense1", "dense2", "lm_head"]) + model_lora2 = create_lora(2, model, ["dense1", "dense2", "lm_head"]) + model_lora3 = create_lora(3, model, ["dense1", "dense2", "lm_head"]) + model_lora4 = create_lora(4, model, ["dense1", "dense2", "lm_head"]) + manager = LRUCacheLoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["dense1", "dense2", "lm_head"]) + + assert all(x is None for x in manager.lora_index_to_id) + + # Add up to capacity + assert manager.add_lora(model_lora1) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(1) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {1, 2} + assert manager.lora_index_to_id[0] == 1 + assert manager.lora_index_to_id[1] == 2 + + # Add over capacity + assert manager.add_lora(model_lora3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(3) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 4 + + # Add 3 again to move it to the top and then add 2 + # should return false since it's in already + assert not manager.add_lora(model_lora3) + assert not manager.activate_lora(3) + assert manager.add_lora(model_lora2) + assert manager.activate_lora(2) + + assert set(manager.list_loras()) == {3, 2} + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 2 + + # Remove manually + assert manager.remove_lora(3) + assert not manager.remove_lora(3) + + assert set(manager.list_loras()) == {2} + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 2 + + assert manager.add_lora(model_lora3) + assert manager.activate_lora(3) + assert manager.add_lora(model_lora4) + assert manager.activate_lora(4) + + assert set(manager.list_loras()) == {3, 4} + assert manager.lora_index_to_id[0] == 3 + assert manager.lora_index_to_id[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == {4} + assert manager.lora_index_to_id[0] is None + assert manager.lora_index_to_id[1] == 4 + + assert manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_index_to_id) + + assert not manager.remove_oldest_lora() + assert set(manager.list_loras()) == set() + assert all(x is None for x in manager.lora_index_to_id) + + +def test_lru_cache_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = LRUCacheWorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 3, 4} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 3 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 4, 5} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 4 + + worker_lora_manager.set_active_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 6, 7, 8} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 7 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 8 + assert worker_lora_manager._lora_manager.lora_index_to_id[3] == 6 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.set_active_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_worker_lora_manager(llama_2_7b_model_extra_embeddings, + sql_lora_files): + # Should remove every LoRA not specified in the request. + lora_config = LoRAConfig(max_lora_rank=8, max_cpu_loras=4, max_loras=4) + worker_lora_manager = WorkerLoRAManager( + 4, 2, llama_2_7b_model_extra_embeddings.config.vocab_size, lora_config, + torch.device("cuda")) + worker_lora_manager.create_lora_manager(llama_2_7b_model_extra_embeddings) + + mapping = LoRAMapping([], []) + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("3", 3, sql_lora_files), + LoRARequest("4", 4, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 3, 4} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 3 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 4 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("2", 2, sql_lora_files), + LoRARequest("5", 5, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1, 2, 5} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 2 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 5 + + worker_lora_manager.set_active_loras([ + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files), + LoRARequest("1", 1, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {1} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 1 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] is None + assert worker_lora_manager._lora_manager.lora_index_to_id[2] is None + + worker_lora_manager.set_active_loras([ + LoRARequest("6", 6, sql_lora_files), + LoRARequest("7", 7, sql_lora_files), + LoRARequest("8", 8, sql_lora_files) + ], mapping) + assert worker_lora_manager.list_loras() == {6, 7, 8} + assert worker_lora_manager._lora_manager.lora_index_to_id[0] == 8 + assert worker_lora_manager._lora_manager.lora_index_to_id[1] == 6 + assert worker_lora_manager._lora_manager.lora_index_to_id[2] == 7 + + # Over capacity + with pytest.raises(RuntimeError): + worker_lora_manager.set_active_loras([ + LoRARequest("10", 10, sql_lora_files), + LoRARequest("11", 11, sql_lora_files), + LoRARequest("12", 12, sql_lora_files), + LoRARequest("13", 13, sql_lora_files), + LoRARequest("14", 14, sql_lora_files) + ], mapping) + + +def test_packed_loras(dist_init, dummy_model_gate_up): + model = dummy_model_gate_up + model_lora = create_packed_lora( + 1, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"]) + model_lora1 = create_packed_lora( + 2, + model, + module_name="gate_up_proj", + replaced_module_names=["gate_proj", "up_proj"], + empty_replaced_module_name="gate_proj", + ) + + manager = LoRAModelManager( + model, 2, 2, 2, + LoRAConfig(max_lora_rank=8, max_cpu_loras=2, max_loras=2), + ["gate_up_proj"]) + model = manager.model + + assert isinstance(model.get_submodule("gate_up_proj"), + MergedColumnParallelLinearWithLoRA) + assert manager.add_lora(model_lora) + assert manager.add_lora(model_lora1) + + packed_lora = model_lora.get_lora("gate_up_proj") + assert packed_lora and isinstance(packed_lora, PackedLoRALayerWeights) + + assert torch.allclose(packed_lora.lora_a[0], + model_lora.get_lora("gate_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[0], + model_lora.get_lora("gate_proj").lora_b) + assert torch.allclose(packed_lora.lora_a[1], + model_lora.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora.lora_b[1], + model_lora.get_lora("up_proj").lora_b) + + packed_lora1 = model_lora1.get_lora("gate_up_proj") + assert packed_lora1 and isinstance(packed_lora1, PackedLoRALayerWeights) + + assert packed_lora1.lora_a[0] is None + assert packed_lora1.lora_b[0] is None + assert torch.allclose(packed_lora1.lora_a[1], + model_lora1.get_lora("up_proj").lora_a) + assert torch.allclose(packed_lora1.lora_b[1], + model_lora1.get_lora("up_proj").lora_b) diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py new file mode 100644 index 0000000000000..903814faa5dc7 --- /dev/null +++ b/tests/lora/test_punica.py @@ -0,0 +1,175 @@ +# Based on code from https://github.com/punica-ai/punica + +import pytest +import torch + +import vllm.lora.punica as punica + + +def assert_close(a, b): + rtol, atol = { + torch.float16: (5e-3, 5e-3), + torch.bfloat16: (3e-2, 2e-2), + torch.float32: (None, None), + }[a.dtype] + torch.testing.assert_close(a, b, rtol=rtol, atol=atol) + + +def _lora_ref_impl( + y_final: torch.Tensor, + x: torch.Tensor, + wa_T_all: torch.Tensor, + wb_T_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, +): + y_stage_1 = torch.empty( + (x.size(0), wa_T_all.size(-2)), + dtype=torch.float32, + device=x.device, + ) + bs = x.shape[0] + s = torch.tensor(scale, dtype=torch.float32, device=x.device) + for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): + xi = x[i].unsqueeze(0).to(torch.float32) + wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + + tmp = xi @ wa + y_stage_1[i] = tmp.squeeze(0) + y_final[i] += (tmp @ wb).squeeze(0) * s + return y_final, y_stage_1 + + +H1 = H2 = [ + 128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120, + 5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000, + 32256, 32512, 32768, 33024 +] +SEED = [0xabcdabcd987] + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness(dtype_str, h1, h2, seed): + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all = torch.randn(num_loras, + num_layers, + h2, + r, + dtype=dtype, + device=device) + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + + y_ref = y.clone() + _lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale) + + y_our = y.clone() + punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx, + scale) + + assert_close(y_ref, y_our) + + +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("h2", H2) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_correctness_slice(dtype_str, h1, h2, seed): + if h2 % 3 != 0 or h2 // 3 not in H1: + pytest.skip("h2 must be divisible by 3 and in supported shapes") + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + r = 8 + bs = 32 + scale = 0.123 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") + + wa_T_all_0 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_1 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wa_T_all_2 = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + wb_T_all_0 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_1 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + wb_T_all_2 = torch.randn(num_loras, + num_layers, + h2 // 3, + r, + dtype=dtype, + device=device) + + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, h2, dtype=dtype, device=device) + s = h2 // 3 + + y_ref = y.clone() + _lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale) + _lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale) + + y_our = y.clone() + punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices, + layer_idx, scale, 0, s) + punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices, + layer_idx, scale, s, s) + punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices, + layer_idx, scale, s * 2, s) + + assert_close(y_ref[:, :s], y_our[:, :s]) + assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2]) + assert_close(y_ref[:, s * 2:], y_our[:, s * 2:]) diff --git a/tests/lora/test_tokenizer.py b/tests/lora/test_tokenizer.py new file mode 100644 index 0000000000000..6c4c91fce8127 --- /dev/null +++ b/tests/lora/test_tokenizer.py @@ -0,0 +1,69 @@ +import pytest +from transformers import AutoTokenizer, PreTrainedTokenizerBase + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer import TokenizerGroup, get_lora_tokenizer + + +@pytest.mark.asyncio +async def test_transformers_tokenizer(): + reference_tokenizer = AutoTokenizer.from_pretrained("gpt2") + tokenizer = TokenizerGroup( + tokenizer_id="gpt2", + enable_lora=False, + max_num_seqs=1, + max_input_length=None, + ) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=None) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=None) + assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + +@pytest.mark.asyncio +async def test_transformers_tokenizer_lora(sql_lora_files): + reference_tokenizer = AutoTokenizer.from_pretrained(sql_lora_files) + tokenizer = TokenizerGroup( + tokenizer_id="gpt2", + enable_lora=True, + max_num_seqs=1, + max_input_length=None, + ) + lora_request = LoRARequest("1", 1, sql_lora_files) + assert reference_tokenizer.encode("prompt") == tokenizer.encode( + request_id="request_id", prompt="prompt", lora_request=lora_request) + assert reference_tokenizer.encode( + "prompt") == await tokenizer.encode_async(request_id="request_id", + prompt="prompt", + lora_request=lora_request) + assert isinstance(tokenizer.get_lora_tokenizer(None), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + None) == await tokenizer.get_lora_tokenizer_async(None) + + assert isinstance(tokenizer.get_lora_tokenizer(lora_request), + PreTrainedTokenizerBase) + assert tokenizer.get_lora_tokenizer( + lora_request) != tokenizer.get_lora_tokenizer(None) + assert tokenizer.get_lora_tokenizer( + lora_request) == await tokenizer.get_lora_tokenizer_async(lora_request) + + +def test_get_lora_tokenizer(sql_lora_files, tmpdir): + lora_request = None + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer + + lora_request = LoRARequest("1", 1, sql_lora_files) + tokenizer = get_lora_tokenizer(lora_request) + assert tokenizer.get_added_vocab() + + lora_request = LoRARequest("1", 1, str(tmpdir)) + tokenizer = get_lora_tokenizer(lora_request) + assert not tokenizer diff --git a/tests/lora/test_utils.py b/tests/lora/test_utils.py new file mode 100644 index 0000000000000..2996322f4aa48 --- /dev/null +++ b/tests/lora/test_utils.py @@ -0,0 +1,172 @@ +from collections import OrderedDict + +from torch import nn + +from vllm.utils import LRUCache +from vllm.lora.utils import (parse_fine_tuned_lora_name, replace_submodule) + + +def test_parse_fine_tuned_lora_name(): + fixture = { + ("base_model.model.lm_head.lora_A.weight", "lm_head", True), + ("base_model.model.lm_head.lora_B.weight", "lm_head", False), + ( + "base_model.model.model.embed_tokens.lora_embedding_A", + "model.embed_tokens", + True, + ), + ( + "base_model.model.model.embed_tokens.lora_embedding_B", + "model.embed_tokens", + False, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_A.weight", + "model.layers.9.mlp.down_proj", + True, + ), + ( + "base_model.model.model.layers.9.mlp.down_proj.lora_B.weight", + "model.layers.9.mlp.down_proj", + False, + ), + } + for name, module_name, is_lora_a in fixture: + assert (module_name, is_lora_a) == parse_fine_tuned_lora_name(name) + + +def test_replace_submodule(): + model = nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(764, 100)), + ("act1", nn.ReLU()), + ("dense2", nn.Linear(100, 50)), + ( + "seq1", + nn.Sequential( + OrderedDict([ + ("dense1", nn.Linear(100, 10)), + ("dense2", nn.Linear(10, 50)), + ])), + ), + ("act2", nn.ReLU()), + ("output", nn.Linear(50, 10)), + ("outact", nn.Sigmoid()), + ])) + + sigmoid = nn.Sigmoid() + + replace_submodule(model, "act1", sigmoid) + assert dict(model.named_modules())["act1"] == sigmoid + + dense2 = nn.Linear(1, 5) + replace_submodule(model, "seq1.dense2", dense2) + assert dict(model.named_modules())["seq1.dense2"] == dense2 + + +class TestLRUCache(LRUCache): + + def _on_remove(self, key, value): + if not hasattr(self, "_remove_counter"): + self._remove_counter = 0 + self._remove_counter += 1 + + +def test_lru_cache(): + cache = TestLRUCache(3) + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(1, 1) + assert len(cache) == 1 + + cache.put(2, 2) + assert len(cache) == 2 + + cache.put(3, 3) + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache.put(4, 4) + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache.get(2) == 2 + + cache.put(5, 5) + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + assert cache.pop(5) == 5 + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.get(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.put(6, 6) + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache + + cache.remove_oldest() + assert len(cache) == 2 + assert set(cache.cache) == {2, 6} + assert cache._remove_counter == 4 + + cache.clear() + assert len(cache) == 0 + assert cache._remove_counter == 6 + + cache._remove_counter = 0 + + cache[1] = 1 + assert len(cache) == 1 + + cache[1] = 1 + assert len(cache) == 1 + + cache[2] = 2 + assert len(cache) == 2 + + cache[3] = 3 + assert len(cache) == 3 + assert set(cache.cache) == {1, 2, 3} + + cache[4] = 4 + assert len(cache) == 3 + assert set(cache.cache) == {2, 3, 4} + assert cache._remove_counter == 1 + assert cache[2] == 2 + + cache[5] = 5 + assert set(cache.cache) == {2, 4, 5} + assert cache._remove_counter == 2 + + del cache[5] + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache.pop(10) + assert len(cache) == 2 + assert set(cache.cache) == {2, 4} + assert cache._remove_counter == 3 + + cache[6] = 6 + assert len(cache) == 3 + assert set(cache.cache) == {2, 4, 6} + assert 2 in cache + assert 4 in cache + assert 6 in cache diff --git a/tests/lora/test_worker.py b/tests/lora/test_worker.py new file mode 100644 index 0000000000000..68c2c0b5fc134 --- /dev/null +++ b/tests/lora/test_worker.py @@ -0,0 +1,61 @@ +import os +import random +import tempfile +from unittest.mock import patch + +from vllm.lora.models import LoRAMapping +from vllm.lora.request import LoRARequest +from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig, LoRAConfig +from vllm.worker.worker import Worker + + +@patch.dict(os.environ, {"RANK": "0"}) +def test_worker_apply_lora(sql_lora_files): + worker = Worker( + model_config=ModelConfig( + "meta-llama/Llama-2-7b-hf", + "meta-llama/Llama-2-7b-hf", + tokenizer_mode="auto", + trust_remote_code=False, + download_dir=None, + load_format="dummy", + seed=0, + dtype="float16", + revision=None, + ), + parallel_config=ParallelConfig(1, 1, False), + scheduler_config=SchedulerConfig(32, 32, 32, 256), + local_rank=0, + rank=0, + lora_config=LoRAConfig(max_lora_rank=8, max_cpu_loras=32, + max_loras=32), + distributed_init_method=f"file://{tempfile.mkstemp()[1]}", + ) + worker.init_model() + worker.load_model() + + worker.model_runner.set_active_loras([], LoRAMapping([], [])) + assert worker.list_loras() == set() + + n_loras = 32 + lora_requests = [ + LoRARequest(str(i + 1), i + 1, sql_lora_files) for i in range(n_loras) + ] + + worker.model_runner.set_active_loras(lora_requests, LoRAMapping([], [])) + assert worker.list_loras() == { + lora_request.lora_int_id + for lora_request in lora_requests + } + + for i in range(32): + random.seed(i) + iter_lora_requests = random.choices(lora_requests, + k=random.randint(1, n_loras)) + random.shuffle(iter_lora_requests) + iter_lora_requests = iter_lora_requests[:-random.randint(0, n_loras)] + worker.model_runner.set_active_loras(iter_lora_requests, + LoRAMapping([], [])) + assert worker.list_loras().issuperset( + {lora_request.lora_int_id + for lora_request in iter_lora_requests}) diff --git a/tests/lora/utils.py b/tests/lora/utils.py new file mode 100644 index 0000000000000..280e0f2043e68 --- /dev/null +++ b/tests/lora/utils.py @@ -0,0 +1,88 @@ +from typing import List, Optional + +import torch + +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights + + +class DummyLoRAManager: + + def __init__(self): + super().__init__() + self._loras = {} + + def set_module_lora(self, module_name: str, lora: LoRALayerWeights): + self._loras[module_name] = lora + + def get_module_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + return self._loras.get(module_name, None) + + def init_random_lora(self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([weight.shape[1], rank], + dtype=weight.dtype, + device="cuda"), + lora_b=torch.rand([rank, weight.shape[0]], + dtype=weight.dtype, + device="cuda"), + ) + if generate_embeddings_tensor: + lora.embeddings_tensor = torch.rand(5, + generate_embeddings_tensor, + dtype=weight.dtype, + device="cuda") + self.set_module_lora(module_name, lora) + + return lora + + def init_lora(self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None): + lora = LoRALayerWeights( + module_name, + rank=rank, + lora_alpha=1, + lora_a=torch.rand([input_dim, rank], device="cuda"), + lora_b=torch.rand([rank, output_dim], device="cuda"), + embeddings_tensor=embeddings_tensor, + ) + self.set_module_lora(module_name, lora) + return lora + + def reset_lora(self): + self._loras = {} + + def init_packed_lora( + self, + module_name: str, + input_dim: int, + output_dims: List[int], + noop_lora_index: List[int] = None, + rank=8, + ): + base_loras = [] + noop_lora_index = set(noop_lora_index or []) + + for i, out_dim in enumerate(output_dims): + base_lora = self.init_lora( + module_name + "_000_" + str(i), + input_dim, + out_dim, + rank=rank, + noop=i in noop_lora_index, + ) + base_loras.append(base_lora) + packed_lora = PackedLoRALayerWeights.pack(base_loras) + self.set_module_lora(module_name, packed_lora) + return packed_lora diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index bcd0cd60bfc52..962183a29fbfa 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -19,10 +19,11 @@ def __init__(self, vocab_size: int, fake_logits: torch.Tensor): self.fake_logits = fake_logits def forward(self, *args, **kwargs): - with patch("vllm.model_executor.layers.sampler._prune_hidden_states", - lambda x, y: x), patch( - "vllm.model_executor.layers.sampler._get_logits", - lambda *args, **kwargs: self.fake_logits): + with patch( + "vllm.model_executor.layers.sampler._prune_hidden_states", + lambda x, y: x), patch( + "vllm.model_executor.layers.sampler.Sampler._get_logits", + lambda *args, **kwargs: self.fake_logits): return super().forward(*args, **kwargs) @@ -38,7 +39,7 @@ def _prepare_test( device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(32000, fake_logits) - model_runner = ModelRunner(None, None, None) + model_runner = ModelRunner(None, None, None, None) return input_tensor, fake_logits, sampler, model_runner @@ -266,7 +267,7 @@ def test_sampler_top_k_top_p(seed: int): device=input_tensor.device, dtype=input_tensor.dtype) sampler = MockLogitsSampler(32000, fake_logits) - model_runner = ModelRunner(None, None, None) + model_runner = ModelRunner(None, None, None, None) generation_model = GenerationMixin() generation_config = GenerationConfig(top_k=top_k, diff --git a/tests/worker/spec_decode/utils.py b/tests/worker/spec_decode/utils.py index 812033829394f..e0db770046ec8 100644 --- a/tests/worker/spec_decode/utils.py +++ b/tests/worker/spec_decode/utils.py @@ -83,8 +83,8 @@ def create_worker(cls: type, enforce_eager=enforce_eager, ) - (model_config, cache_config, parallel_config, - scheduler_config) = engine_args.create_engine_configs() + (model_config, cache_config, parallel_config, scheduler_config, + _) = engine_args.create_engine_configs() distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index edbe10684741f..5d9ad0520de13 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -6,7 +6,7 @@ def test_prepare_prompt(): - model_runner = ModelRunner(None, None, None) + model_runner = ModelRunner(None, None, None, None) model_runner.set_block_size(16) batch_size = random.randint(1, 256) @@ -33,7 +33,7 @@ def test_prepare_prompt(): expected_selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_seq_len - input_tokens, input_positions, _, return_prompt_lens, _ = ( + input_tokens, input_positions, _, return_prompt_lens, _, _, _, _ = ( model_runner._prepare_prompt(seq_group_metadata_list)) assert return_prompt_lens == prompt_lens sampling_metadata = model_runner._prepare_sample(seq_group_metadata_list, diff --git a/vllm/config.py b/vllm/config.py index f1efcc66e9097..8acd15a3b7d9a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,4 +1,5 @@ -from typing import Optional, Union +from typing import Optional, Union, ClassVar +from dataclasses import dataclass import os import torch @@ -397,6 +398,54 @@ def _verify_args(self) -> None: f"({self.max_num_seqs}).") +@dataclass +class LoRAConfig: + max_lora_rank: int + max_loras: int + max_cpu_loras: Optional[int] = None + lora_dtype: Optional[torch.dtype] = None + lora_extra_vocab_size: int = 256 + # This is a constant. + lora_vocab_padding_size: ClassVar[int] = 256 + + def __post_init__(self): + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + possible_max_ranks = (8, 16, 32, 64) + possible_lora_extra_vocab_size = (0, 256, 512) + if self.max_lora_rank not in possible_max_ranks: + raise ValueError( + f"max_lora_rank ({self.max_lora_rank}) must be one of " + f"{possible_max_ranks}.") + if self.lora_extra_vocab_size not in possible_lora_extra_vocab_size: + raise ValueError( + f"lora_extra_vocab_size ({self.lora_extra_vocab_size}) " + f"must be one of {possible_lora_extra_vocab_size}.") + if self.max_loras < 1: + raise ValueError(f"max_loras ({self.max_loras}) must be >= 1.") + if self.max_cpu_loras is None: + self.max_cpu_loras = self.max_loras + elif self.max_cpu_loras < self.max_loras: + raise ValueError( + f"max_cpu_loras ({self.max_cpu_loras}) must be >= " + f"max_num_seqs ({self.max_loras})") + + def verify_with_model_config(self, model_config: ModelConfig): + if self.lora_dtype in (None, "auto"): + self.lora_dtype = model_config.dtype + elif isinstance(self.lora_dtype, str): + self.lora_dtype = getattr(torch, self.lora_dtype) + if model_config.quantization is not None: + raise ValueError( + "LoRA is not supported with quantized models yet.") + + def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig): + if scheduler_config.max_num_batched_tokens > 65528: + raise ValueError( + "Due to limitations of the custom LoRA CUDA kernel, " + "max_num_batched_tokens must be <= 65528 when " + "LoRA is enabled.") + + _STR_DTYPE_TO_TORCH_DTYPE = { "half": torch.float16, "float16": torch.float16, diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 69e3d5993c37f..4fdf9ec341cfd 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1,11 +1,12 @@ from collections import deque import enum import time -from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union +from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union, Set -from vllm.config import CacheConfig, SchedulerConfig +from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory +from vllm.lora.request import LoRARequest from vllm.logger import init_logger from vllm.sequence import (Sequence, SequenceData, SequenceGroup, SequenceGroupMetadata, SequenceStatus) @@ -49,11 +50,25 @@ def __init__( assert not (blocks_to_swap_in and blocks_to_swap_out) self.ignored_seq_groups = ignored_seq_groups + self.num_loras = len(self.lora_requests) + if self.num_loras > 0: + self._sort_by_lora_ids() + def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. return (not self.scheduled_seq_groups and not self.blocks_to_swap_in and not self.blocks_to_swap_out and not self.blocks_to_copy) + def _sort_by_lora_ids(self) -> bool: + self.scheduled_seq_groups = sorted( + self.scheduled_seq_groups, + key=lambda g: (g.lora_request.lora_int_id + if g.lora_request else 0, g.request_id)) + + @property + def lora_requests(self) -> Set[LoRARequest]: + return {g.lora_request for g in self.scheduled_seq_groups} + class Scheduler: @@ -61,9 +76,14 @@ def __init__( self, scheduler_config: SchedulerConfig, cache_config: CacheConfig, + lora_config: Optional[LoRAConfig], ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config + # Note for LoRA scheduling: the current policy is extremely + # simple and NOT fair. It can lead to starvation of some + # LoRAs. This should be improved in the future. + self.lora_config = lora_config self.prompt_limit = min(self.scheduler_config.max_model_len, self.scheduler_config.max_num_batched_tokens) @@ -87,6 +107,10 @@ def __init__( # Sequence groups in the SWAPPED state. self.swapped: Deque[SequenceGroup] = deque() + @property + def lora_enabled(self) -> bool: + return bool(self.lora_config) + def add_seq_group(self, seq_group: SequenceGroup) -> None: # Add sequence groups to the waiting queue. self.waiting.append(seq_group) @@ -150,14 +174,17 @@ def _schedule(self) -> SchedulerOutputs: # requests in the generation phase. num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted # sequence groups are added to the front and the new sequence groups # are added to the back. + leftover_waiting_sequences = deque() while self.waiting: seq_group = self.waiting[0] - waiting_seqs = seq_group.get_seqs( status=SequenceStatus.WAITING) assert len(waiting_seqs) == 1, ( @@ -188,6 +215,17 @@ def _schedule(self) -> SchedulerOutputs: self.waiting.popleft() continue + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_waiting_sequences.appendleft(seq_group) + self.waiting.popleft() + continue + # If the number of batched tokens exceeds the limit, stop. new_seq_lens = seq_lens + [num_prompt_tokens] num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) @@ -207,12 +245,16 @@ def _schedule(self) -> SchedulerOutputs: break seq_lens = new_seq_lens - seq_group = self.waiting.popleft() + if lora_int_id > 0: + curr_loras.add(lora_int_id) + self.waiting.popleft() self._allocate(seq_group) self.running.append(seq_group) num_curr_seqs += num_new_seqs scheduled.append(seq_group) + self.waiting.extendleft(leftover_waiting_sequences) + if scheduled or ignored_seq_groups: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, @@ -260,9 +302,25 @@ def _schedule(self) -> SchedulerOutputs: if not preempted: num_curr_seqs = sum(seq_group.get_max_num_running_seqs() for seq_group in self.running) + curr_loras = set( + seq_group.lora_int_id + for seq_group in self.running) if self.lora_enabled else None + + leftover_swapped = deque() while self.swapped: seq_group = self.swapped[0] + lora_int_id = 0 + if self.lora_enabled: + lora_int_id = seq_group.lora_int_id + if lora_int_id > 0 and lora_int_id not in curr_loras and len( + curr_loras) >= self.lora_config.max_loras: + # We don't have a space for another LoRA, so + # we ignore this request for now. + leftover_swapped.appendleft(seq_group) + self.swapped.popleft() + continue + # If the sequence group cannot be swapped in, stop. if not self.block_manager.can_swap_in(seq_group): break @@ -274,12 +332,16 @@ def _schedule(self) -> SchedulerOutputs: self.scheduler_config.max_num_seqs): break - seq_group = self.swapped.popleft() + if lora_int_id > 0: + curr_loras.add(lora_int_id) + self.swapped.popleft() self._swap_in(seq_group, blocks_to_swap_in) self._append_slot(seq_group, blocks_to_copy) num_curr_seqs += num_new_seqs self.running.append(seq_group) + self.swapped.extendleft(leftover_swapped) + # Each sequence in the generation phase only takes one token slot. # Therefore, the number of batched tokens is equal to the number of # sequences in the RUNNING state. @@ -320,6 +382,7 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: seq_data=seq_data, sampling_params=seq_group.sampling_params, block_tables=block_tables, + lora_request=seq_group.lora_request, prefix=seq_group.prefix, ) seq_group_metadata_list.append(seq_group_metadata) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 7e58069e2c22d..090fa95bcac02 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -4,7 +4,7 @@ from typing import Optional, Tuple from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) @dataclass @@ -35,6 +35,12 @@ class EngineArgs: quantization: Optional[str] = None enforce_eager: bool = False max_context_len_to_capture: int = 8192 + enable_lora: bool = False + max_loras: int = 1 + max_lora_rank: int = 16 + lora_extra_vocab_size: int = 256 + lora_dtype = 'auto' + max_cpu_loras: Optional[int] = None def __post_init__(self): if self.tokenizer is None: @@ -202,6 +208,39 @@ def add_cli_args( help='maximum context length covered by CUDA ' 'graphs. When a sequence has context length ' 'larger than this, we fall back to eager mode.') + # LoRA related configs + parser.add_argument('--enable-lora', + action='store_true', + help='If True, enable handling of LoRA adapters.') + parser.add_argument('--max-loras', + type=int, + default=EngineArgs.max_loras, + help='Max number of LoRAs in a single batch.') + parser.add_argument('--max-lora-rank', + type=int, + default=EngineArgs.max_lora_rank, + help='Max LoRA rank.') + parser.add_argument( + '--lora-extra-vocab-size', + type=int, + default=EngineArgs.lora_extra_vocab_size, + help=('Maximum size of extra vocabulary that can be ' + 'present in a LoRA adapter (added to the base ' + 'model vocabulary).')) + parser.add_argument( + '--lora-dtype', + type=str, + default=EngineArgs.lora_dtype, + choices=['auto', 'float16', 'bfloat16', 'float32'], + help=('Data type for LoRA. If auto, will default to ' + 'base model dtype.')) + parser.add_argument( + '--max-cpu-loras', + type=int, + default=EngineArgs.max_cpu_loras, + help=('Maximum number of LoRAs to store in CPU memory. ' + 'Must be >= than max_num_seqs. ' + 'Defaults to max_num_seqs.')) return parser @classmethod @@ -214,7 +253,8 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, - ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: + ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig, + Optional[LoRAConfig]]: model_config = ModelConfig(self.model, self.tokenizer, self.tokenizer_mode, self.trust_remote_code, self.download_dir, self.load_format, @@ -234,7 +274,14 @@ def create_engine_configs( self.max_num_seqs, model_config.max_model_len, self.max_paddings) - return model_config, cache_config, parallel_config, scheduler_config + lora_config = LoRAConfig( + max_lora_rank=self.max_lora_rank, + max_loras=self.max_loras, + lora_extra_vocab_size=self.lora_extra_vocab_size, + lora_dtype=self.lora_dtype, + max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras + and self.max_cpu_loras > 0 else None) if self.enable_lora else None + return model_config, cache_config, parallel_config, scheduler_config, lora_config @dataclass diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 11ae0d2025c63..3d009c7818548 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -4,6 +4,7 @@ from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union, AsyncIterator) +from vllm.lora.request import LoRARequest from vllm.config import ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine @@ -206,6 +207,52 @@ async def step_async(self) -> List[RequestOutput]: scheduler_outputs, ts_start=ts_start) + async def encode_request_async( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = await self.tokenizer.encode_async( + request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + + async def add_request_async( + self, + request_id: str, + prompt: Optional[str], + sampling_params: SamplingParams, + prompt_token_ids: Optional[List[int]] = None, + arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, + prefix_pos: Optional[int] = None, + ) -> None: + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + + return self.add_request( + request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + sampling_params=sampling_params, + arrival_time=arrival_time, + lora_request=lora_request, + prefix_pos=prefix_pos, + ) + async def _run_workers_async( self, method: str, @@ -335,7 +382,7 @@ async def engine_step(self) -> bool: if self.engine_use_ray: await self.engine.add_request.remote(**new_request) else: - self.engine.add_request(**new_request) + await self.engine.add_request_async(**new_request) if finished_requests: await self._engine_abort(finished_requests) @@ -374,6 +421,7 @@ async def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, prefix_pos: Optional[int] = None, ) -> AsyncStream: if self.log_requests: @@ -389,7 +437,8 @@ async def add_request( f"prompt: {shortened_prompt!r}, " f"prefix_pos: {prefix_pos}," f"sampling params: {sampling_params}, " - f"prompt token ids: {shortened_token_ids}.") + f"prompt token ids: {shortened_token_ids}, " + f"lora_request: {lora_request}.") if not self.is_running: if self.start_engine_loop: @@ -401,12 +450,21 @@ async def add_request( "error that caused the background loop to stop " "(AsyncEngineDeadError).") + if arrival_time is None: + arrival_time = time.time() + prompt_token_ids = await self.engine.encode_request_async( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) + stream = self._request_tracker.add_request( request_id, prompt=prompt, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time, + lora_request=lora_request, prefix_pos=prefix_pos) return stream @@ -417,6 +475,7 @@ async def generate( sampling_params: SamplingParams, request_id: str, prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, prefix_pos: Optional[int] = None, ) -> AsyncIterator[RequestOutput]: """Generate outputs for a request. @@ -432,6 +491,7 @@ async def generate( request_id: The unique id of the request. prompt_token_ids: The token IDs of the prompt. If None, we use the tokenizer to convert the prompts to token IDs. + lora_request: LoRA request to use for generation, if any. prefix_pos: If not None, we use the given position as the prefix position for each prompt. We will cache the prefix's KV cache and reuse it for the next request with the same prefix. @@ -490,12 +550,15 @@ async def generate( arrival_time = time.monotonic() try: - stream = await self.add_request(request_id, - prompt, - sampling_params, - prompt_token_ids=prompt_token_ids, - arrival_time=arrival_time, - prefix_pos=prefix_pos) + stream = await self.add_request( + request_id, + prompt, + sampling_params, + prompt_token_ids=prompt_token_ids, + arrival_time=arrival_time, + lora_request=lora_request, + prefix_pos=prefix_pos, + ) async for request_output in stream: yield request_output diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 18890bbe81d61..df15dd753407d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -5,8 +5,9 @@ from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union) +from vllm.lora.request import LoRARequest from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import record_metrics @@ -17,7 +18,7 @@ from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) + TokenizerGroup) from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port, get_distributed_init_method if ray: @@ -64,6 +65,7 @@ def __init__( cache_config: CacheConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], placement_group: Optional["PlacementGroup"], log_stats: bool, ) -> None: @@ -87,17 +89,13 @@ def __init__( self.model_config = model_config self.cache_config = cache_config + self.lora_config = lora_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.log_stats = log_stats self._verify_args() - self.tokenizer = get_tokenizer( - model_config.tokenizer, - tokenizer_mode=model_config.tokenizer_mode, - trust_remote_code=model_config.trust_remote_code, - tokenizer_revision=model_config.tokenizer_revision, - revision=model_config.revision) + self._init_tokenizer() self.seq_counter = Counter() # Create the parallel GPU workers. @@ -114,7 +112,7 @@ def __init__( self._init_cache() # Create the scheduler. - self.scheduler = Scheduler(scheduler_config, cache_config) + self.scheduler = Scheduler(scheduler_config, cache_config, lora_config) # Logging. self.last_logging_time = 0.0 @@ -127,6 +125,9 @@ def __init__( # LIst of (timestamp, delta_time) self.prompt_times_ms: List[Tuple[float, float]] = [] + def get_tokenizer_for_seq(self, sequence: Sequence): + return self.tokenizer.get_lora_tokenizer(sequence.lora_request) + def _init_workers(self): # Lazy import the Worker to avoid importing torch.cuda/xformers # before CUDA_VISIBLE_DEVICES is set in the Worker @@ -145,11 +146,24 @@ def _init_workers(self): local_rank=0, rank=0, distributed_init_method=distributed_init_method, + lora_config=self.lora_config, is_driver_worker=True, ) self._run_workers("init_model") self._run_workers("load_model") + def _init_tokenizer(self, **tokenizer_init_kwargs): + init_kwargs = dict( + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + tokenizer_mode=self.model_config.tokenizer_mode, + trust_remote_code=self.model_config.trust_remote_code, + revision=self.model_config.tokenizer_revision) + init_kwargs.update(tokenizer_init_kwargs) + self.tokenizer: TokenizerGroup = TokenizerGroup( + self.model_config.tokenizer, **init_kwargs) + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if self.parallel_config.tensor_parallel_size == 1: @@ -237,6 +251,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank, rank, distributed_init_method, + lora_config=self.lora_config, )) driver_rank = 0 @@ -248,6 +263,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", driver_local_rank, driver_rank, distributed_init_method, + lora_config=self.lora_config, is_driver_worker=True, ) @@ -261,6 +277,10 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", def _verify_args(self) -> None: self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) + if self.lora_config: + self.lora_config.verify_with_model_config(self.model_config) + self.lora_config.verify_with_scheduler_config( + self.scheduler_config) def _init_cache(self) -> None: """Profiles the memory usage and initializes the KV cache. @@ -336,6 +356,20 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": log_stats=not engine_args.disable_log_stats) return engine + def encode_request( + self, + request_id: str, # pylint: disable=unused-argument + prompt: Optional[str], + prompt_token_ids: Optional[List[int]] = None, + lora_request: Optional[LoRARequest] = None, + ): + if prompt_token_ids is None: + assert prompt is not None + prompt_token_ids = self.tokenizer.encode(request_id=request_id, + prompt=prompt, + lora_request=lora_request) + return prompt_token_ids + def add_request( self, request_id: str, @@ -343,6 +377,7 @@ def add_request( sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, + lora_request: Optional[LoRARequest] = None, prefix_pos: Optional[int] = None, ) -> None: """Add a request to the engine's request pool. @@ -390,24 +425,31 @@ def add_request( >>> # continue the request processing >>> ... """ + if lora_request is not None and not self.lora_config: + raise ValueError(f"Got lora_request {lora_request} but LoRA is " + "not enabled!") if arrival_time is None: arrival_time = time.monotonic() - if prompt_token_ids is None: - assert prompt is not None - prompt_token_ids = self.tokenizer.encode(prompt) + prompt_token_ids = self.encode_request( + request_id=request_id, + prompt=prompt, + prompt_token_ids=prompt_token_ids, + lora_request=lora_request) # Create the sequences. block_size = self.cache_config.block_size seq_id = next(self.seq_counter) - seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) + seq = Sequence(seq_id, prompt, prompt_token_ids, block_size, + lora_request) # Check whether the input specifies prefix prefix = self.scheduler.prefix_pool.add_or_get_prefix( - prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None + prompt_token_ids[:prefix_pos], lora_request.lora_int_id + if lora_request else 0) if prefix_pos is not None else None # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time, prefix) + arrival_time, lora_request, prefix) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) @@ -457,11 +499,13 @@ def _check_beam_search_early_stopping( current_worst_score = (current_worst_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + current_worst_seq).eos_token_id)) if early_stopping is False: highest_attainable_score = (best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -475,7 +519,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id, seq_len=max_possible_length)) else: # Otherwise, beam search will prefer shorter sequences. The @@ -484,7 +529,8 @@ def _check_beam_search_early_stopping( highest_attainable_score = ( best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.get_tokenizer_for_seq( + best_running_seq).eos_token_id)) return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -575,7 +621,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the finished sequences by their scores. all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), reverse=True) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: @@ -603,7 +649,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Sort the running sequences by their scores. running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), + eos_token_id=self.get_tokenizer_for_seq(x[0]).eos_token_id), reverse=True) # Check if we can stop the beam search. @@ -885,7 +931,7 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" (new_tokens, new_output_text, prefix_offset, read_offset) = detokenize_incrementally( - self.tokenizer, + self.get_tokenizer_for_seq(seq), all_input_ids=seq.get_token_ids(), prev_tokens=seq.tokens, prefix_offset=seq.prefix_offset, @@ -927,11 +973,28 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == self.tokenizer.eos_token_id): + if ((not sampling_params.ignore_eos) and seq.get_last_token_id() + == self.get_tokenizer_for_seq(seq).eos_token_id): seq.status = SequenceStatus.FINISHED_STOPPED return + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> List[int]: + return self._run_workers("list_loras") + def _run_workers( self, method: str, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index b819e233c06b2..aab0c9615f411 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -3,6 +3,7 @@ from tqdm import tqdm from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from vllm.lora.request import LoRARequest from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.outputs import RequestOutput @@ -122,6 +123,7 @@ def generate( prompt_token_ids: Optional[List[List[int]]] = None, prefix_pos: Optional[Union[int, List[int]]] = None, use_tqdm: bool = True, + lora_request: Optional[LoRARequest] = None, ) -> List[RequestOutput]: """Generates the completions for the input prompts. @@ -141,6 +143,7 @@ def generate( This is an experimental feature, and may be replaced with automatic prefix caching in the future. use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. Returns: A list of `RequestOutput` objects containing the generated @@ -168,7 +171,11 @@ def generate( prefix_pos_i = prefix_pos[i] if prefix_pos is not None else None token_ids = None if prompt_token_ids is None else prompt_token_ids[ i] - self._add_request(prompt, sampling_params, token_ids, prefix_pos_i) + self._add_request(prompt, + sampling_params, + token_ids, + lora_request=lora_request, + prefix_pos=prefix_pos_i) return self._run_engine(use_tqdm) def _add_request( @@ -176,6 +183,7 @@ def _add_request( prompt: Optional[str], sampling_params: SamplingParams, prompt_token_ids: Optional[List[int]], + lora_request: Optional[LoRARequest] = None, prefix_pos: Optional[int] = None, ) -> None: request_id = str(next(self.request_counter)) @@ -183,6 +191,7 @@ def _add_request( prompt, sampling_params, prompt_token_ids, + lora_request=lora_request, prefix_pos=prefix_pos) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: diff --git a/vllm/lora/__init__.py b/vllm/lora/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py new file mode 100644 index 0000000000000..e1aac20b038b4 --- /dev/null +++ b/vllm/lora/layers.py @@ -0,0 +1,975 @@ +# pylint: disable=unused-argument +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, List, Optional, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.lora.punica import add_lora, add_lora_slice, bgmv +from vllm.model_executor.layers.sampler import Sampler +from vllm.model_executor.parallel_utils.communication_op import ( + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, + tensor_model_parallel_gather, +) +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear, + QKVParallelLinear, + MergedColumnParallelLinear) +from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding, ParallelLMHead +from vllm.model_executor.parallel_utils.parallel_state import ( + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.model_executor.parallel_utils.utils import split_tensor_along_last_dim + +if TYPE_CHECKING: + pass + + +def _apply_lora( + x: torch.Tensor, + lora_a_stacked: torch.Tensor, + lora_b_stacked: torch.Tensor, + indices: torch.Tensor, + output: torch.Tensor, +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: (num_loras, lora_rank, hidden_dim) + lora_b_stacked: (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0) + return output.view_as(org_output) + + +def _apply_lora_packed_nslice( + x: torch.Tensor, + lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], +): + """Applies lora to each input. + + This method applies all loras to each input. It uses the + indices vector to determine which lora yields the + correct output. An index of -1 means no lora should be + applied. This method adds the final lora results to the + output. + + This method is used for layers that are composed of multiple sublayers + (slices) packed together. + + Input shapes: + x: (batch_size, hidden_dim) + lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim) + lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), where n is number of slices + """ + org_output = output + x = x.view(-1, x.shape[-1]) + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + offset_left = 0 + for slice_idx in range(len(output_slices)): + add_lora_slice(output, x, lora_a_stacked[slice_idx], + lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left, + output_slices[slice_idx]) + offset_left += output_slices[slice_idx] + return output.view_as(org_output) + + +@dataclass +class LoRAMapping: + # Per every token in input_ids: + index_mapping: Tuple[int, ...] + # Per sampled token: + prompt_mapping: Tuple[int, ...] + + def __post_init__(self): + self.index_mapping = tuple(self.index_mapping) + self.prompt_mapping = tuple(self.prompt_mapping) + + +class BaseLayerWithLoRA(nn.Module): + + def create_lora_weights(self, max_loras: int, lora_config: LoRAConfig, + model_config: PretrainedConfig) -> None: + """Initializes lora matrices.""" + ... + + def reset_lora(self, index: int): + """Resets the lora weights at index back to 0.""" + ... + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + """Overwrites lora tensors at index.""" + ... + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + """Sets the mapping indices.""" + ... + + +class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: VocabParallelEmbedding) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + + lora_vocab_start_idx = self.base_layer.org_vocab_size + weights_idx = None + if self.base_layer.vocab_end_index > lora_vocab_start_idx: + # We can start adding lora weights + weights_idx = max( + lora_vocab_start_idx - self.base_layer.vocab_start_index, 0) + self.embeddings_slice = (self.base_layer.vocab_start_index - + self.base_layer.org_vocab_size + + weights_idx, + self.base_layer.vocab_end_index - + self.base_layer.org_vocab_size) + self.embeddings_weights = self.base_layer.weight.data[weights_idx:] + self.embeddings_weights.fill_(0) + else: + self.embeddings_slice = None + self.embeddings_weights = None + + self.embeddings_tensors = torch.zeros( + ( + max_loras, + lora_config.lora_extra_vocab_size, + self.base_layer.embedding_dim, + ), + dtype=self.base_layer.weight.dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + self.base_layer.org_vocab_size + + lora_config.lora_extra_vocab_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.embedding_dim, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_a_stacked_2d = self.lora_a_stacked.view( + self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1], + self.lora_a_stacked.shape[2], + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.embeddings_indices = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, :lora_a.shape[0], :lora_a.shape[1]].copy_( + lora_a, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. + shape[1]].copy_(embeddings_tensor, non_blocking=True) + if self.embeddings_slice is not None: + # TODO(yard1): Optimize this copy, we don't need to copy + # everything, just the modified part + embeddings = self.embeddings_tensors.view( + self.embeddings_tensors.shape[0] * + self.embeddings_tensors.shape[1], + self.embeddings_tensors.shape[2] + )[self.embeddings_slice[0]:self.embeddings_slice[1]] + self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.embeddings_indices = embeddings_indices + self.indices_len = indices_len + + def forward(self, x: torch.Tensor) -> torch.Tensor: + added_tokens_mask = x > self.base_layer.org_vocab_size - 1 + indices = self.embeddings_indices[1][:self.indices_len[3]].view_as(x) + full_lora_a_embeddings = F.embedding( + x + indices, + self.lora_a_stacked_2d, + ) + indices = self.embeddings_indices[0][:self.indices_len[3]].view_as(x) + full_output = self.base_layer.forward( + x.add_(indices * added_tokens_mask)) + + full_output_org = full_output + if full_output.ndim == 3: + full_output = full_output.view( + full_output.shape[0] * full_output.shape[1], -1) + if full_lora_a_embeddings.ndim == 3: + full_lora_a_embeddings = full_lora_a_embeddings.view( + full_lora_a_embeddings.shape[0] * + full_lora_a_embeddings.shape[1], -1) + bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + return full_output.view_as(full_output_org) + + +class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: ColumnParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + self.output_dim = self.lora_b_stacked.shape[1] + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + def forward(self, input_): + """Forward of ColumnParallelLinear + + Args: + input_: Tensor whose last dimension is `input_size`. + + Returns: + - output + - bias + """ + bias = (self.base_layer.bias + if not self.base_layer.skip_bias_add else None) + + # Matrix multiply. + output_parallel = self.apply_weights(input_, bias) + if self.base_layer.gather_output: + # All-gather across the partitions. + output = tensor_model_parallel_all_gather(output_parallel) + else: + output = output_parallel + output_bias = (self.base_layer.bias + if self.base_layer.skip_bias_add else None) + return output, output_bias + + @property + def linear_weights(self): + return self.base_layer.linear_weights + + +class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 2 sublayers (slices) + packed together (eg. gate_proj + up_proj -> gate_up_proj). + + This means we have 2 LoRAs, each applied to one half of the layer. + + Both slices must have the same size. + """ + + def __init__(self, base_layer: MergedColumnParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + n_slices = 2 + if not (len(self.base_layer.output_sizes) == n_slices + and self.base_layer.output_sizes[0] + == self.base_layer.output_sizes[1]): + raise ValueError( + "LoRAColumnParallelLinear2Slice requires 2 slices with " + "the same size.") + self.tp_size = get_tensor_model_parallel_world_size() + + self.lora_a_stacked = tuple( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + self.lora_b_stacked = tuple( + torch.zeros( + max_loras, + 1, + self.base_layer.weight.shape[0] // 2, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) for _ in range(n_slices)) + + self.indices: Optional[torch.Tensor] = None + self.output_dim = self.lora_b_stacked[0].shape[2] + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_b_stacked[1][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[0][:, + start_idx:end_idx], lora_b[1][:, + start_idx:end_idx] + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_nslice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + (self.output_dim, self.output_dim), + ) + return output + + +class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): + """ColumnParallelLinear layer that is composed of 3 sublayers (slices) + packed together in qkv proj fashion + (q_proj + k_proj + v_proj -> qkv_proj). + + This means we have 3 LoRAs, each applied to one slice of the layer. + + Q slice may have different shape than K and V slices (which both have + the same shape). + """ + + def __init__(self, base_layer: QKVParallelLinear) -> None: + super().__init__(base_layer) + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + self.q_proj_shard_size = (self.base_layer.num_heads * + self.base_layer.head_size) + self.kv_proj_shard_size = (self.base_layer.num_kv_heads * + self.base_layer.head_size) + self.q_shard_id = tp_rank + self.kv_shard_id = tp_rank // self.base_layer.num_kv_head_replicas + + # q, k, v + self.lora_a_stacked = ( + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + ) + self.lora_b_stacked = ( + torch.zeros( + max_loras, + 1, + self.q_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + torch.zeros( + max_loras, + 1, + self.kv_proj_shard_size, + lora_config.max_lora_rank, + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ), + ) + + self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size, + self.kv_proj_shard_size) + self.packed_indices: Optional[torch.Tensor] = None + self.standard_indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[0][index] = 0 + self.lora_b_stacked[0][index] = 0 + self.lora_a_stacked[1][index] = 0 + self.lora_b_stacked[1][index] = 0 + self.lora_a_stacked[2][index] = 0 + self.lora_b_stacked[2][index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + + if self.tp_size > 1: + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) + else: + if lora_b[0] is not None: + self.lora_b_stacked[0][ + index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( + lora_b[0].T, non_blocking=True) + if lora_b[1] is not None: + self.lora_b_stacked[1][ + index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( + lora_b[1].T, non_blocking=True) + if lora_b[2] is not None: + self.lora_b_stacked[2][ + index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( + lora_b[2].T, non_blocking=True) + + if lora_a[0] is not None: + self.lora_a_stacked[0][ + index, 0, :lora_a[0].shape[1], :lora_a[0].shape[0]].copy_( + lora_a[0].T, non_blocking=True) + if lora_a[1] is not None: + self.lora_a_stacked[1][ + index, 0, :lora_a[1].shape[1], :lora_a[1].shape[0]].copy_( + lora_a[1].T, non_blocking=True) + if lora_a[2] is not None: + self.lora_a_stacked[2][ + index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( + lora_a[2].T, non_blocking=True) + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x, bias) + _apply_lora_packed_nslice( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + self.output_slices, + ) + return output + + +class RowParallelLinearWithLoRA(BaseLayerWithLoRA): + + def __init__(self, base_layer: RowParallelLinear) -> None: + super().__init__() + self.base_layer = base_layer + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.base_layer.weight.shape[1], + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + self.base_layer.weight.shape[0], + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.base_layer.weight.device, + ) + self.indices: Optional[torch.Tensor] = None + self.indices_len: Optional[List[int]] = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + if self.base_layer.tp_size > 1: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.base_layer.weight.shape[1] + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = base_indices + self.indices_len = indices_len + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer.linear_weights, x) + _apply_lora( + x, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[0]], + output, + ) + return output + + def forward(self, input_): + """Forward of RowParallelLinear + + Args: + input_: tensor whose last dimension is `input_size`. If + `input_is_parallel` is set, then the last dimension + is `input_size // tp_size`. + + Returns: + - output + - bias + """ + # Set up backprop all-reduce. + if self.base_layer.input_is_parallel: + input_parallel = input_ + else: + # TODO: simplify code below + tp_rank = get_tensor_model_parallel_rank() + splitted_input = split_tensor_along_last_dim( + input_, num_partitions=self.base_layer.tp_size) + input_parallel = splitted_input[tp_rank].contiguous() + + # Matrix multiply. + output_parallel = self.apply_weights(input_parallel) + if self.base_layer.reduce_results and self.base_layer.tp_size > 1: + output_ = tensor_model_parallel_all_reduce(output_parallel) + else: + output_ = output_parallel + + if not self.base_layer.skip_bias_add: + output = (output_ + self.base_layer.bias + if self.base_layer.bias is not None else output_) + output_bias = None + else: + output = output_ + output_bias = self.base_layer.bias + return output, output_bias + + @property + def weight(self): + return self.base_layer.weight + + +class SamplerWithLoRA(BaseLayerWithLoRA): + + def __init__( + self, + base_layer: Sampler, + hidden_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> None: + super().__init__() + self.base_layer = base_layer + self.hidden_size = hidden_size + self.dtype = dtype + self.device = device + + @property + def vocab_size(self): + return self.base_layer.vocab_size + + @property + def org_vocab_size(self): + return self.base_layer.org_vocab_size + + @property + def include_gpu_probs_tensor(self): + return self.base_layer.include_gpu_probs_tensor + + def create_lora_weights( + self, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, + ) -> None: + # Keep this in sync with csrc/punica/bgmv/bgmv_config.h + if 32000 < self.base_layer.vocab_size > 33024: + raise ValueError( + "When using LoRA, vocab size must be 32000 >= vocab_size <= 33024" + ) + self.lora_a_stacked = torch.zeros( + ( + max_loras, + 1, + lora_config.max_lora_rank, + self.hidden_size, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.lora_b_stacked = torch.zeros( + ( + max_loras, + 1, + # Pad for kernel compatibility + math.ceil(self.base_layer.vocab_size / + lora_config.lora_vocab_padding_size) * + lora_config.lora_vocab_padding_size, + lora_config.max_lora_rank, + ), + dtype=lora_config.lora_dtype, + device=self.device, + ) + self.embeddings_tensors = torch.full( + (max_loras, lora_config.lora_extra_vocab_size, self.hidden_size), + fill_value=float("-inf"), + dtype=self.dtype, + device=self.device, + ) + self.indices = None + self.indices_padded = None + self.indices_len = None + + def reset_lora(self, index: int): + self.lora_a_stacked[index] = 0 + self.lora_b_stacked[index] = 0 + self.embeddings_tensors[index] = float("-inf") + + def set_lora( + self, + index: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor], + ): + self.reset_lora(index) + self.lora_a_stacked[index, + 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( + lora_a.T, non_blocking=True) + self.lora_b_stacked[index, + 0, :lora_b.shape[1], :lora_b.shape[0]].copy_( + lora_b.T, non_blocking=True) + if embeddings_tensor is not None: + self.embeddings_tensors[ + index, :embeddings_tensor.shape[0], :embeddings_tensor. + shape[1], ] = embeddings_tensor + + def set_mapping( + self, + base_indices: torch.Tensor, + sampler_indices: torch.Tensor, + sampler_indices_padded: torch.Tensor, + embeddings_indices: torch.Tensor, + indices_len: List[int], + ): + self.indices = sampler_indices + self.indices_padded = sampler_indices_padded + self.indices_len = indices_len + + def _get_logits( + self, + hidden_states: torch.Tensor, + embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_gather(logits) + if logits is None: + return None + + lora_logits = torch.empty( + self.embeddings_tensors.shape[0] + 1, + self.embeddings_tensors.shape[1], + hidden_states.shape[0], + dtype=self.embeddings_tensors.dtype, + device=self.embeddings_tensors.device, + ) + torch.matmul(self.embeddings_tensors, + hidden_states.T, + out=lora_logits[:-1]) + lora_logits[-1] = float("-inf") + lora_logits = lora_logits.mT + lora_logits = (lora_logits.reshape( + lora_logits.shape[0] * lora_logits.shape[1], + lora_logits.shape[2], + ).index_select(0, + self.indices_padded[:self.indices_len[2]]).nan_to_num_( + nan=float("-inf"), + posinf=float("inf"), + neginf=float("-inf"))) + logits[:, + self.base_layer.org_vocab_size:self.base_layer.org_vocab_size + + lora_logits.shape[1]] = lora_logits + + _apply_lora( + hidden_states, + self.lora_a_stacked, + self.lora_b_stacked, + self.indices[:self.indices_len[1]], + logits, + ) + + # Remove paddings in vocab (if any). + logits = logits[:, :self.base_layer.vocab_size] + + return logits + + def forward(self, *args, **kwargs): + return type(self.base_layer).forward(self, *args, **kwargs) + + +def from_layer( + layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None) -> BaseLayerWithLoRA: + supported_layer_types = { + VocabParallelEmbedding: VocabParallelEmbeddingWithLoRA, + ColumnParallelLinear: ColumnParallelLinearWithLoRA, + QKVParallelLinear: QKVParallelLinearWithLora, + MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA, + RowParallelLinear: RowParallelLinearWithLoRA, + } + for src_layer_type, lora_layer_type in supported_layer_types.items(): + if type(layer) is src_layer_type: # pylint: disable=unidiomatic-typecheck + ret = lora_layer_type(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_sampler( + layer: Sampler, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> SamplerWithLoRA: + ret = SamplerWithLoRA(layer, lm_head.embedding_dim, lm_head.weight.dtype, + lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret diff --git a/vllm/lora/lora.py b/vllm/lora/lora.py new file mode 100644 index 0000000000000..fbb228c9582d4 --- /dev/null +++ b/vllm/lora/lora.py @@ -0,0 +1,160 @@ +from typing import List, Optional + +import torch +from vllm.utils import in_wsl + + +class LoRALayerWeights: + """LoRA weights for a layer composed of two low rank matrixes.""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alpha: int, + lora_a: torch.Tensor, + lora_b: torch.Tensor, + embeddings_tensor: Optional[torch.Tensor] = None, + scaling: Optional[float] = None, + ) -> None: + self.module_name = module_name + self.rank = rank + self.lora_alpha = lora_alpha + self.lora_a = lora_a + self.lora_b = lora_b + self.embeddings_tensor = embeddings_tensor + + if scaling is None: + self.scaling = self.lora_alpha / self.rank + else: + self.scaling = scaling + + def optimize(self) -> "LoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + if self.scaling == 1: + return + self.lora_b *= self.scaling + self.scaling = 1 + return self + + @property + def input_dim(self) -> int: + return self.lora_a.shape[0] + + @property + def output_dim(self) -> int: + return self.lora_b.shape[1] + + @property + def is_packed(self) -> bool: + return False + + @property + def extra_vocab_size(self) -> int: + return self.embeddings_tensor.shape[ + 0] if self.embeddings_tensor is not None else 0 + + @classmethod + def create_dummy_lora_weights( + cls, + module_name: str, + input_dim: int, + output_dim: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + embeddings_tensor_dim: Optional[int] = None) -> "LoRALayerWeights": + pin_memory = str(device) == "cpu" and not in_wsl() + lora_a = torch.zeros([input_dim, rank], + dtype=dtype, + device=device, + pin_memory=pin_memory) + lora_b = torch.zeros([rank, output_dim], + dtype=dtype, + device=device, + pin_memory=pin_memory) + embeddings_tensor = torch.rand( + 10, + embeddings_tensor_dim, + dtype=dtype, + device=device, + pin_memory=pin_memory) if embeddings_tensor_dim else None + return cls( + module_name, + rank=rank, + lora_alpha=1, + lora_a=lora_a, + lora_b=lora_b, + embeddings_tensor=embeddings_tensor, + ) + + +class PackedLoRALayerWeights(LoRALayerWeights): + """LoRA used for packed layers (eg. qkv_proj).""" + + def __init__( + self, + module_name: str, + rank: int, + lora_alphas: List[int], + lora_a: List[torch.Tensor], + lora_b: List[torch.Tensor], + scaling: Optional[List[float]] = None, + ) -> None: + super().__init__( + module_name=module_name, + rank=rank, + lora_alpha=0, + lora_a=lora_a, + lora_b=lora_b, + scaling=scaling, + embeddings_tensor=None, + ) + self.lora_alphas = lora_alphas + if scaling is None: + self.scaling = [ + lora_alpha / self.rank for lora_alpha in self.lora_alphas + ] + + @classmethod + def pack(cls, loras: List["LoRALayerWeights"]) -> "PackedLoRALayerWeights": + """Pack a list of LoRAs into a single LoRA. + + If LoRA is None, it signifies that the submodule does not have a LoRA. + """ + first_lora = next(lora for lora in loras if lora is not None) + for lora in loras: + if lora is None: + continue + lora.optimize() + rank = first_lora.rank + module_name = first_lora.module_name + obj = cls( + module_name, + rank, + [lora.lora_alpha if lora is not None else None for lora in loras], + [lora.lora_a if lora is not None else None for lora in loras], + [lora.lora_b if lora is not None else None for lora in loras], + scaling=[1 if lora is not None else None for lora in loras]) + return obj + + def optimize(self) -> "PackedLoRALayerWeights": + """Optimize the LoRA by merging the scaling into lora_b.""" + for i in range(len(self.lora_b)): + if self.scaling[i] == 1 or self.lora_b[i] is None: + continue + self.lora_b[i] *= self.scaling[i] + self.scaling[i] = 1 + return self + + @property + def input_dim(self) -> int: + raise NotImplementedError() + + @property + def output_dim(self) -> int: + raise NotImplementedError() + + @property + def is_packed(self) -> bool: + return True diff --git a/vllm/lora/models.py b/vllm/lora/models.py new file mode 100644 index 0000000000000..6c78c4a2c7771 --- /dev/null +++ b/vllm/lora/models.py @@ -0,0 +1,654 @@ +import copy +import json +import logging +import math +import os +import re +from typing import (Any, Callable, Dict, Hashable, List, Optional, Tuple, Type, + Union) + +import safetensors.torch +import torch +from torch import nn + +from vllm.config import LoRAConfig +from vllm.utils import LRUCache, in_wsl + +from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping, from_layer, from_layer_sampler +from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights +from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule + +logger = logging.getLogger(__name__) + +# TODO: The mappings below should be moved to individual model classes. + +PACKED_MODULES_CFG = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], +} + +TARGET_MODULES_QKV = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + "embed_tokens", + "lm_head", +] + +EMBEDDING_MODULES = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", +} + +EMBEDDING_PADDING_MODULES = ["lm_head"] + +_GLOBAL_LORA_ID = 0 + + +def convert_mapping( + mapping: LoRAMapping, lora_index_to_id: List[Optional[int]], + max_loras: int, vocab_size: int, extra_vocab_size: int +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[int]]: + """Converts LoRAMapping to index tensors. + + Args: + mapping: LoRAMapping mapping rows in a batch to LoRA ids. + lora_index_to_id: List mapping LoRA ids to LoRA indices. + max_loras: Maximum number of LoRAs. + vocab_size: Model vocab size. + extra_vocab_size: Extra vocab size each LoRA can have. + + Returns: + A tuple of tensors: + base_indices: Tensor of shape [batch_size] mapping batch rows to + LoRA indices. + sampler_indices: Tensor of shape [batch_size] mapping requests to + LoRA indices for sampler. For generation, this will be the + same as base_indicies. For prefill, this will map requests + to LoRA indices. + sampler_indices_padded: Tensor of shape [batch_size] mapping + requests to LoRA indices for sampler with padding. + Same as sampler_indicies, but -1 is replaced with + max_loras. + embeddings_indices: Tensor of shape [2, batch_size] mapping + requests to embedding indices. First row is for embeddings + added by the LoRAs, second row is for the LoRA.lora_a + embeddings. + indices_len: List of lengths of the above tensors. + """ + indices = list(mapping.index_mapping).copy() + embedding_indices = indices.copy() + lora_indices = indices.copy() + prompt_mapping = [ + lora_index_to_id.index(x) if x > 0 else -1 + for x in mapping.prompt_mapping + ] + lora_idx = None + for i in range(len(indices)): + # TODO index can be slow. optimize + lora_idx = (lora_index_to_id.index(indices[i]) + if indices[i] > 0 else -1) + embedding_indices[i] = lora_idx if indices[i] > 0 else 0 + indices[i] = i + lora_indices[i] = lora_idx + + indices = torch.tensor([indices, lora_indices, embedding_indices], + dtype=torch.long, + device="cuda") + prompt_mapping = torch.tensor(prompt_mapping, + device="cuda", + dtype=torch.long) + embeddings_indices = torch.stack([ + indices[2] * extra_vocab_size, + indices[2] * (vocab_size + extra_vocab_size) + ]) + embeddings_indices[embeddings_indices == -1] = max_loras - 1 + base_indices = indices[1] + sampler_indices = prompt_mapping + sampler_indices_padded = sampler_indices.clone() + sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1 + sampler_indices_padded = ( + torch.arange( + 0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + + (sampler_indices_padded * len(sampler_indices_padded))) + indices_len = (base_indices.shape[-1], sampler_indices.shape[-1], + sampler_indices_padded.shape[-1], + embeddings_indices.shape[-1]) + + return (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, indices_len) + + +def get_lora_id(): + global _GLOBAL_LORA_ID + _GLOBAL_LORA_ID += 1 + return _GLOBAL_LORA_ID + + +class LoRAModel: + """A LoRA fine-tuned model.""" + + def __init__( + self, + lora_model_id: int, + rank: int, + loras: Dict[str, LoRALayerWeights], + ) -> None: + self.id = lora_model_id + assert (lora_model_id > + 0), f"a valid lora id should be greater than 0, got {self.id}" + self.rank = rank + self.loras: Dict[str, LoRALayerWeights] = loras + + @property + def extra_vocab_size(self) -> int: + return max(lora.extra_vocab_size + for lora in self.loras.values()) if self.loras else 0 + + def get_lora(self, module_name: str) -> Optional[LoRALayerWeights]: + """Get LoRA for a given module by name""" + return self.loras.get(module_name, None) + + # (yard1): TODO see if we can derive target_embedding_padding automatically + @classmethod + def from_lora_tensors( + cls, + lora_model_id: int, + rank: int, + lora_alpha: int, + tensors: Dict[str, torch.Tensor], + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + embeddings: Optional[Dict[str, torch.Tensor]] = None, + target_embedding_padding: Optional[int] = None, + ) -> "LoRAModel": + """Create a LoRAModel from a dictionary of tensors.""" + pin_memory = str(device) == "cpu" and not in_wsl() + loras: Dict[str, LoRALayerWeights] = {} + for tensor_name, tensor in tensors.items(): + module_name, is_lora_a = parse_fine_tuned_lora_name(tensor_name) + if module_name not in loras: + lora_embeddings_tensor = None + if embeddings: + embeddings_module = next( + (k for k in EMBEDDING_MODULES if k in module_name), + None) + if embeddings_module: + lora_embeddings_tensor = embeddings[ + EMBEDDING_MODULES[embeddings_module]].to( + device=device, dtype=dtype) + if pin_memory: + lora_embeddings_tensor = ( + lora_embeddings_tensor.pin_memory()) + loras[module_name] = LoRALayerWeights(module_name, rank, + lora_alpha, None, None, + lora_embeddings_tensor) + if is_lora_a: + loras[module_name].lora_a = tensor.to(device=device, + dtype=dtype).t() + if pin_memory: + loras[module_name].lora_a = loras[ + module_name].lora_a.pin_memory() + else: + loras[module_name].lora_b = tensor.to(device=device, + dtype=dtype).t() + if any(name in module_name + for name in EMBEDDING_PADDING_MODULES + ) and target_embedding_padding is not None: + lora_b = loras[module_name].lora_b + assert target_embedding_padding >= lora_b.shape[1] + addition = target_embedding_padding - lora_b.shape[1] + loras[module_name].lora_b = torch.nn.functional.pad( + lora_b, (0, addition)) + if pin_memory: + loras[module_name].lora_b = loras[ + module_name].lora_b.pin_memory() + + for lora in loras.values(): + lora.optimize() + return cls(lora_model_id, rank, loras) + + @classmethod + def from_local_checkpoint( + cls, + lora_dir: str, + lora_model_id: Optional[int] = None, + device: str = "cuda", + dtype: Optional[torch.dtype] = None, + target_embedding_padding: Optional[int] = None) -> "LoRAModel": + """Create a LoRAModel from a local checkpoint.""" + lora_config_path = os.path.join(lora_dir, "adapter_config.json") + lora_tensor_path = os.path.join(lora_dir, "adapter_model.safetensors") + lora_bin_file_path = os.path.join(lora_dir, "adapter_model.bin") + new_embeddings_tensor_path = os.path.join( + lora_dir, "new_embeddings.safetensors") + new_embeddings_bin_file_path = os.path.join(lora_dir, + "new_embeddings.bin") + if os.path.isfile(lora_tensor_path): + tensors = safetensors.torch.load_file(lora_tensor_path) + elif os.path.isfile(lora_bin_file_path): + tensors = torch.load(lora_bin_file_path) + else: + raise ValueError(f"{lora_dir} doesn't contain tensors") + + embeddings = None + if os.path.isfile(new_embeddings_tensor_path): + embeddings = safetensors.torch.load_file( + new_embeddings_tensor_path) + elif os.path.isfile(new_embeddings_bin_file_path): + embeddings = torch.load(new_embeddings_bin_file_path) + + with open(lora_config_path) as f: + config = json.load(f) + rank = config["r"] + lora_alpha = config["lora_alpha"] + return cls.from_lora_tensors( + lora_model_id=get_lora_id() + if lora_model_id is None else lora_model_id, + rank=rank, + lora_alpha=lora_alpha, + tensors=tensors, + device=device, + dtype=dtype, + embeddings=embeddings, + target_embedding_padding=target_embedding_padding, + ) + + +class LoRAModelManager: + """A manager that manages multiple LoRA-fine-tuned models.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + """Create a LoRAModelManager and adapter for a given model. + + Args: + model: the model to be adapted. + max_num_seqs: the maximum number of sequences model can run in a + single batch. + max_num_batched_tokens: the maximum number of tokens model can run + in a single batch. + vocab_size: the vocab size of the model. + lora_config: the LoRA configuration. + lora_target_modules: the target modules patterns to be adapted. + Support both single module name and a list of module names. + packed_modules_mapping: the mapping for packed modules. vLLM + packs some modules into one module, e.g., qkv_proj + is packed of q_proj, k_proj, and v_proj. These modules + have a single layer in the original model, but they are split + into multiple layers in the adapted model. + """ + self.lora_config = lora_config + self.max_num_seqs = max_num_seqs + assert self.capacity >= self.lora_slots + self.max_num_batched_tokens = math.ceil(max_num_batched_tokens / 8) * 8 + self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots + self.vocab_size = vocab_size + self.base_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.embeddings_indices = torch.empty(2, + self.max_num_batched_tokens, + dtype=torch.long, + device="cuda") + self.offsets = [] + # 4 is the number of indicies tensors defined above + # base_indices, sampler_indices, sampler_indices_padded, + # embeddings_indices + self.indices_len = [None] * 4 + + self.model: nn.Module = model + self.lora_target_modules: List[str] = ([ + lora_target_modules + ] if isinstance(lora_target_modules, str) else lora_target_modules) + self.lora_target_modules = copy.deepcopy(lora_target_modules) + self.packed_modules_mapping = copy.deepcopy(packed_modules_mapping) + self.packed_modules: Dict[str, List[str]] = {} + self.modules: Dict[str, "BaseLayerWithLoRA"] = {} + self._registered_loras: Dict[int, LoRAModel] = {} + # Dict instead of a Set for compatibility with LRUCache. + self._active_loras: Dict[int, None] = {} + self._last_mapping = None + self._create_lora_modules() + self.model.lora_manager = self + + @property + def capacity(self) -> int: + return self.lora_config.max_cpu_loras + + @property + def lora_slots(self) -> int: + return self.lora_config.max_loras + + def __len__(self) -> int: + return len(self._registered_loras) + + def activate_lora( + self, + lora_id: int, + ) -> bool: + """Move LoRA into a GPU buffer to be used in the forward pass.""" + if lora_id in self._active_loras: + return False + first_free_slot = next( + ((i, lora_id) for i, lora_id in enumerate(self.lora_index_to_id) + if lora_id is None), None) + if first_free_slot is None: + raise ValueError("No free lora slots") + index, _ = first_free_slot + self._active_loras[lora_id] = None + lora_model = self._registered_loras[lora_id] + logger.debug( + f"Activating LoRA. int id: {lora_model.id}, slot index: {index}") + self.lora_index_to_id[index] = lora_model.id + for module_name, module in self.modules.items(): + module_lora = lora_model.get_lora(module_name) + if module_lora: + module_lora.optimize() + module.set_lora(index, module_lora.lora_a, module_lora.lora_b, + module_lora.embeddings_tensor) + else: + module.reset_lora(index) + return True + + def _deactivate_lora(self, lora_id: int): + try: + index = self.lora_index_to_id.index(lora_id) + self.lora_index_to_id[index] = None + except ValueError: + pass + + def deactivate_lora(self, lora_id: int) -> bool: + """Remove a LoRA from a GPU buffer.""" + if lora_id in self._active_loras: + self._deactivate_lora(lora_id) + self._active_loras.pop(lora_id) + return True + return False + + def _add_lora(self, lora: LoRAModel) -> bool: + self._create_merged_loras_inplace(lora) + self._registered_loras[lora.id] = lora + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager CPU cache.""" + if lora.id not in self._registered_loras: + if len(self._registered_loras) >= self.capacity: + raise RuntimeError("No free LoRA slots.") + self._add_lora(lora) + return True + return False + + def remove_lora(self, lora_id: int) -> bool: + """Remove a LoRAModel from the manager CPU cache.""" + # TODO: should we check active lora? + self.deactivate_lora(lora_id) + return bool(self._registered_loras.pop(lora_id, None)) + + # TODO see if this can be vectorized + def _set_lora_mapping(self, mapping: LoRAMapping) -> None: + (base_indices, sampler_indices, sampler_indices_padded, + embeddings_indices, + indices_len) = convert_mapping(mapping, self.lora_index_to_id, + self.lora_slots + 1, self.vocab_size, + self.lora_config.lora_extra_vocab_size) + self.base_indices[:base_indices.shape[0]].copy_(base_indices) + self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices) + self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_( + sampler_indices_padded) + self.embeddings_indices[:embeddings_indices. + shape[0], :embeddings_indices.shape[1]].copy_( + embeddings_indices) + # Maintain the reference + self.indices_len[:] = indices_len + + def set_lora_mapping(self, lora_mapping: LoRAMapping) -> None: + if self._last_mapping != lora_mapping: + self._set_lora_mapping(lora_mapping) + self._last_mapping = lora_mapping + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras) + + def get_lora(self, lora_id: int) -> Optional[LoRAModel]: + return self._registered_loras.get(lora_id, None) + + def remove_all_loras(self) -> bool: + """Remove all LoRAModels from the manager.""" + self._registered_loras.clear() + self.lora_index_to_id = [None] * self.lora_slots + self._active_loras.clear() + + def _create_lora_modules(self): + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name): + continue + + new_module = replace_submodule( + self.model, module_name, + from_layer(module, self.lora_slots, self.lora_config, + self.model.config)) + # (yard1): TODO make this more robust + if "lm_head" in module_name: + sampler_module = self.model.get_submodule("sampler") + new_module = replace_submodule( + self.model, "sampler", + from_layer_sampler(sampler_module, module, self.lora_slots, + self.lora_config, self.model.config)) + self.register_module(module_name, new_module) + self._register_packed_modules(module_name) + new_module.set_mapping(self.base_indices, self.sampler_indices, + self.sampler_indices_padded, + self.embeddings_indices, self.indices_len) + + def register_module(self, module_name: str, module: "BaseLayerWithLoRA"): + assert isinstance(module, BaseLayerWithLoRA) + self.modules[module_name] = module + + def create_dummy_lora(self, lora_id: int, rank: int) -> LoRAModel: + """Create zero-initialized LoRAModel for warmup.""" + model = LoRAModel(lora_id, rank, {}) + for module_name, module in self.model.named_modules(): + if not self._match_target_modules(module_name) or not isinstance( + module, BaseLayerWithLoRA): + continue + parts = module_name.split(".") + if module_name not in self.packed_modules: + if parts[-1] in EMBEDDING_MODULES: + input_dim = (module.base_layer.org_vocab_size + + self.lora_config.lora_extra_vocab_size if + hasattr(module.base_layer, "org_vocab_size") + else module.base_layer.weight.shape[1]) + output_dim = module.base_layer.embedding_dim if hasattr( + module.base_layer, + "embedding_dim") else module.base_layer.weight.shape[0] + embeddings_tensor_dim = (module.base_layer.embedding_dim if + hasattr(module.base_layer, + "embedding_dim") else + module.base_layer.weight.shape[1]) + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + input_dim, + output_dim, + rank, + module.lora_a_stacked.dtype, + "cpu", + embeddings_tensor_dim=embeddings_tensor_dim) + else: + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name, + module.lora_a_stacked.shape[-1], + module.lora_b_stacked.shape[-2], + rank, + module.lora_a_stacked.dtype, + "cpu", + ) + lora.optimize() + else: + parts = module_name.split(".") + replacements = self.packed_modules_mapping[parts[-1]] + subloras = [] + for i, r in enumerate(replacements): + lora = LoRALayerWeights.create_dummy_lora_weights( + module_name + "." + r, + module.lora_a_stacked[i].shape[-1], + module.lora_b_stacked[i].shape[-2], + rank, + module.lora_a_stacked[i].dtype, + "cpu", + ) + lora.optimize() + subloras.append(lora) + lora = PackedLoRALayerWeights.pack(subloras) + model.loras[module_name] = lora + return model + + def _match_target_modules(self, module_name: str): + return any( + re.match( + r".*\.{target_module}$".format(target_module=target_module), + module_name) or target_module == module_name + for target_module in self.lora_target_modules) + + def _register_packed_modules(self, module_full_name: str) -> None: + parts = module_full_name.split(".") + module_name = parts[-1] + replacements = self.packed_modules_mapping.get(module_name) + if not replacements: + return + prefix = ".".join(parts[:-1]) + self.packed_modules[module_full_name] = [ + prefix + "." + r if prefix else r for r in replacements + ] + + def _create_merged_loras_inplace(self, lora_model: LoRAModel) -> None: + for module_name, new_module_names in self.packed_modules.items(): + replacement_loras = [] + has_replacement = False + for r in new_module_names: + lora = lora_model.get_lora(r) + replacement_loras.append(lora) + if lora: + has_replacement = True + if not has_replacement: + continue + for i in range(len(replacement_loras)): + if replacement_loras[i]: + continue + replacement_loras[i] = None + lora_model.loras[module_name] = PackedLoRALayerWeights.pack( + replacement_loras) + + +class LoRALRUCache(LRUCache): + + def __init__(self, capacity: int, deactivate_lora_fn: Callable[[Hashable], + None]): + super().__init__(capacity) + self.deactivate_lora_fn = deactivate_lora_fn + + def _on_remove(self, key: Hashable, value: Any): + logger.debug(f"Removing LoRA. int id: {key}") + self.deactivate_lora_fn(key) + return super()._on_remove(key, value) + + +class LRUCacheLoRAModelManager(LoRAModelManager): + """A model manager that manages multiple LoRAs with LRU cache.""" + + def __init__( + self, + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + lora_target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + packed_modules_mapping: Dict[str, List[str]] = PACKED_MODULES_CFG, + ): + super().__init__(model, max_num_seqs, max_num_batched_tokens, + vocab_size, lora_config, lora_target_modules, + packed_modules_mapping) + self._registered_loras: LoRALRUCache = LoRALRUCache( + self.capacity, self.deactivate_lora) + self._active_loras: LoRALRUCache = LoRALRUCache( + self.lora_slots, self._deactivate_lora) + + def list_loras(self) -> Dict[int, LoRAModel]: + """List all registered LoRAModels.""" + return dict(self._registered_loras.cache) + + def add_lora(self, lora: LoRAModel) -> bool: + """Add a LoRAModel to the manager.""" + if lora.id not in self._registered_loras: + self._add_lora(lora) + was_added = True + else: + # We always touch to update the LRU cache order + self._registered_loras.touch(lora.id) + was_added = False + return was_added + + def activate_lora( + self, + lora_id: int, + ) -> bool: + if lora_id not in self._active_loras and len( + self._active_loras) >= self.lora_slots: + self._active_loras.remove_oldest() + result = super().activate_lora(lora_id) + # We always touch to update the LRU cache order + self._active_loras.touch(lora_id) + return result + + def remove_oldest_lora(self) -> bool: + if len(self._registered_loras) > 0: + self._registered_loras.remove_oldest() + return True + return False + + +def create_lora_manager( + model: nn.Module, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager, + **kwargs) -> LoRAModelManager: + """Create a LoRA adapter for a given model.""" + if not getattr(model, "supports_lora", False): + raise ValueError(f"Model {type(model)} is not supported for LoRA.") + lora_manager = lora_manager_cls( + model=model, + max_num_seqs=max_num_seqs, + max_num_batched_tokens=max_num_batched_tokens, + vocab_size=vocab_size, + lora_config=lora_config, + lora_target_modules=target_modules, + **kwargs) + return lora_manager diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py new file mode 100644 index 0000000000000..ac96931b2d071 --- /dev/null +++ b/vllm/lora/punica.py @@ -0,0 +1,173 @@ +# Based on code from https://github.com/punica-ai/punica + +from typing import Optional + +import torch + +import_exc = None + +try: + import vllm._punica_C as punica_kernels +except ImportError as e: + import_exc = e + +if import_exc is None: + + def bgmv( + y: torch.Tensor, + x: torch.Tensor, + w_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + ): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight + matrices. + indicies: Shape: `[B]`. Indices of the weight matrices. + layer_idx: Layer index of the weight matrices. + scale: Scaling factor. + """ + punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) + + def add_lora(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + *, + buffer: Optional[torch.Tensor] = None): + """ + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + buffer: Optional. Shape: `[B, R]`. Temporary buffer. + """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical innacuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, + 1.0) + punica_kernels.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, + scale) + + def add_lora_slice(y: torch.Tensor, + x: torch.Tensor, + wa_t_all: torch.Tensor, + wb_t_all: torch.Tensor, + indicies: torch.LongTensor, + layer_idx: int, + scale: float, + y_offset: int, + y_slice_size: int, + *, + buffer: Optional[torch.Tensor] = None): + """ + Same as `add_lora` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. + + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed + LoRA A matrices. + wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed + LoRA B matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. + y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + r = wb_t_all.size(-1) + if buffer is None: + # We set the buffer to be float32 by default to avoid + # numerical inaccuracies that would otherwise happen + # due to downcasting. + buffer = torch.zeros((x.size(0), r), + dtype=torch.float32, + device=x.device) + punica_kernels.dispatch_bgmv_low_level( + buffer, + x, + wa_t_all, + indicies, + layer_idx, + 1.0, + x.size(1), + buffer.size(1), + 0, + ) + punica_kernels.dispatch_bgmv_low_level( + y, + buffer, + wb_t_all, + indicies, + layer_idx, + scale, + buffer.size(1), + y_slice_size, + y_offset, + ) + +else: + + def _raise_exc( + *args, # pylint: disable=unused-argument + **kwargs # pylint: disable=unused-argument + ): + if torch.cuda.get_device_capability() < (8, 0): + raise ImportError( + "LoRA kernels require compute capability>=8.0") from import_exc + else: + raise import_exc + + bgmv = _raise_exc + add_lora = _raise_exc + add_lora_slice = _raise_exc + +__all__ = [ + "bgmv", + "add_lora", + "add_lora_slice", +] diff --git a/vllm/lora/request.py b/vllm/lora/request.py new file mode 100644 index 0000000000000..bbbf4880ab81b --- /dev/null +++ b/vllm/lora/request.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass + + +@dataclass +class LoRARequest: + """ + Request for a LoRA adapter. + + Note that this class should be be used internally. For online + serving, it is recommended to not allow users to use this class but + instead provide another layer of abstraction to prevent users from + accessing unauthorized LoRA adapters. + + lora_int_id must be globally unique for a given adapter. + This is currently not enforced in vLLM. + """ + + lora_name: str + lora_int_id: int + lora_local_path: str + + def __post_init__(self): + if self.lora_int_id < 1: + raise ValueError( + f"lora_int_id must be > 0, got {self.lora_int_id}") + + def __eq__(self, value: object) -> bool: + return isinstance( + value, LoRARequest) and self.lora_int_id == value.lora_int_id + + def __hash__(self) -> int: + return self.lora_int_id diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py new file mode 100644 index 0000000000000..f67a3812fb046 --- /dev/null +++ b/vllm/lora/utils.py @@ -0,0 +1,39 @@ +import logging +from typing import Tuple + +from torch import nn + +logger = logging.getLogger(__name__) + + +def replace_submodule(model: nn.Module, module_name: str, + new_module: nn.Module) -> nn.Module: + """Replace a submodule in a model with a new module.""" + parent = model.get_submodule(".".join(module_name.split(".")[:-1])) + target_name = module_name.split(".")[-1] + setattr(parent, target_name, new_module) + return new_module + + +def parse_fine_tuned_lora_name(name: str) -> Tuple[str, bool]: + """Parse the name of lora weights. + + args: + name: the name of the fine-tuned LoRA, e.g. + base_model.model.dense1.weight + return: + Tuple(module_name, is_lora_a): + module_name: the name of the module, e.g. model.dense1, + is_lora_a whether the tensor is lora_a or lora_b. + """ + parts = name.split(".") + assert parts[0] == "base_model" + assert parts[1] == "model" + if parts[-1] == "weight": + assert parts[-2] == "lora_A" or parts[-2] == "lora_B" + return ".".join(parts[2:-2]), parts[-2] == "lora_A" + + if parts[-1] == "lora_embedding_A" or parts[-1] == "lora_embedding_B": + return ".".join(parts[2:-1]), parts[-1] == "lora_embedding_A" + + raise ValueError(f"{name} is unsupported format") diff --git a/vllm/lora/worker_manager.py b/vllm/lora/worker_manager.py new file mode 100644 index 0000000000000..a507c08588dad --- /dev/null +++ b/vllm/lora/worker_manager.py @@ -0,0 +1,237 @@ +import logging +from abc import ABC, abstractmethod, abstractproperty +from typing import Any, List, Optional, Set, Type, Union + +import torch + +from vllm.lora.models import (TARGET_MODULES_QKV, LoRAModel, LoRAModelManager, + LRUCacheLoRAModelManager, create_lora_manager) +from vllm.lora.request import LoRARequest +from vllm.lora.layers import LoRAMapping +from vllm.config import LoRAConfig + +logger = logging.getLogger(__name__) + + +class WorkerLoRAManager(ABC): + """Abstract class for managing LoRA models on the worker side.""" + + def __init__(self, max_num_seqs: int, max_num_batched_tokens: int, + vocab_size: int, lora_config: LoRAConfig, + device: torch.device): + self.max_num_seqs = max_num_seqs + self.max_num_batched_tokens = max_num_batched_tokens + self.vocab_size = vocab_size + self.device = device + self.lora_config = lora_config + + @abstractproperty + def is_enabled(self) -> bool: + ... + + @abstractmethod + def create_lora_manager( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + ... + + @abstractmethod + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + ... + + @abstractmethod + def add_lora(self, lora_request: LoRARequest) -> bool: + ... + + @abstractmethod + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + ... + + @abstractmethod + def remove_lora(self, lora_id: int) -> bool: + ... + + @abstractmethod + def remove_all_loras(self) -> bool: + ... + + @abstractmethod + def list_loras(self) -> Set[int]: + ... + + +class WorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Every request, the requested LoRAs will be loaded (unless they are already + loaded), and every other LoRA will be unloaded.""" + + _lora_manager_cls: Type[LoRAModelManager] = LoRAModelManager + + def __init__( + self, + max_num_seqs: int, + max_num_batched_tokens: int, + vocab_size: int, + lora_config: LoRAConfig, + device: torch.device, + lora_model_cls: Type[LoRAModel] = LoRAModel, + ): + self._lora_manager: Optional[LoRAModelManager] = None + self._lora_model_cls = lora_model_cls + super().__init__(max_num_seqs, max_num_batched_tokens, vocab_size, + lora_config, device) + + @property + def is_enabled(self) -> bool: + return True + + def create_lora_manager( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_manager( + model, + max_num_seqs=self.max_num_seqs, + max_num_batched_tokens=self.max_num_batched_tokens, + target_modules=target_modules, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + lora_manager_cls=self._lora_manager_cls, + ) + self._lora_manager: LoRAModelManager = lora_manager + return lora_manager.model + + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + self._apply_loras(lora_requests) + self._lora_manager.set_lora_mapping(lora_mapping) + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_that_exist = self.list_loras() + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.lora_slots}).") + + new_loras = set(loras_map) + loras_to_add = new_loras - loras_that_exist + loras_to_remove = loras_that_exist - new_loras + + for lora_id in loras_to_remove: + self.remove_lora(lora_id) + + for lora_id in loras_to_add: + self.add_lora(loras_map[lora_id]) + + def _load_lora(self, lora_request: LoRARequest) -> LoRAModel: + try: + lora = self._lora_model_cls.from_local_checkpoint( + lora_request.lora_local_path, + lora_model_id=lora_request.lora_int_id, + device="cpu", + dtype=self.lora_config.lora_dtype, + target_embedding_padding=self.vocab_size + + self.lora_config.lora_extra_vocab_size, + ) + except Exception as e: + raise RuntimeError( + f"Loading lora {lora_request.lora_local_path} failed") from e + if lora.rank > self.lora_config.max_lora_rank: + raise ValueError( + f"LoRA rank {lora.rank} is greater than max_lora_rank " + f"{self.lora_config.max_lora_rank}.") + if lora.extra_vocab_size > self.lora_config.lora_extra_vocab_size: + raise ValueError( + f"LoRA added vocab size {lora.extra_vocab_size} is greater than " + f"lora_extra_vocab_size {self.lora_config.lora_extra_vocab_size}." + ) + return lora + + def add_dummy_lora(self, lora_request: LoRARequest, rank: int) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + return self._lora_manager.add_lora( + self._lora_manager.create_dummy_lora(lora_request.lora_int_id, + rank)) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id in self.list_loras(): + return False + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + self._lora_manager.activate_lora(lora.id) + return loaded + + def remove_lora(self, lora_id: int) -> bool: + return self._lora_manager.remove_lora(lora_id) + + def remove_all_loras(self) -> bool: + self._lora_manager.remove_all_loras() + + def list_loras(self) -> Set[int]: + return set(self._lora_manager.list_loras()) + + +class LRUCacheWorkerLoRAManager(WorkerLoRAManager): + """WorkerLoRAManager that manages LoRA models on the worker side. + + Uses an LRU Cache. Every request, the requested LoRAs will be loaded + (unless they are already loaded) and least recently used LoRAs will + be unloaded if the cache is above capacity.""" + + _lora_manager_cls: Type[ + LRUCacheLoRAModelManager] = LRUCacheLoRAModelManager + + def create_lora_manager( + self, + model: torch.nn.Module, + target_modules: Union[str, List[str]] = TARGET_MODULES_QKV, + ) -> Any: + lora_manager = create_lora_manager( + model, + target_modules=target_modules, + lora_manager_cls=self._lora_manager_cls, + max_num_seqs=self.max_num_seqs, + vocab_size=self.vocab_size, + lora_config=self.lora_config, + max_num_batched_tokens=self.max_num_batched_tokens, + ) + self._lora_manager: LRUCacheLoRAModelManager = lora_manager + return lora_manager.model + + def _apply_loras(self, lora_requests: List[LoRARequest]) -> None: + loras_map = { + lora_request.lora_int_id: lora_request + for lora_request in lora_requests if lora_request + } + if len(loras_map) > self._lora_manager.lora_slots: + raise RuntimeError( + f"Number of requested LoRAs ({len(loras_map)}) is greater " + "than the number of GPU LoRA slots " + f"({self._lora_manager.lora_slots}).") + for lora in loras_map.values(): + self.add_lora(lora) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if lora_request.lora_int_id not in self.list_loras(): + # Remove before we load the new lora to save memory + if len(self._lora_manager) + 1 > self._lora_manager.capacity: + self._lora_manager.remove_oldest_lora() + lora = self._load_lora(lora_request) + loaded = self._lora_manager.add_lora(lora) + else: + # If the lora is already loaded, just touch it to + # update its position in the caches + loaded = self._lora_manager.get_lora(lora_request.lora_int_id) + self._lora_manager.activate_lora(lora_request.lora_int_id) + return loaded diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index e8b1d3e570ff9..bc86a916b5bbf 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -27,9 +27,25 @@ class Sampler(nn.Module): parameters (e.g., sampling method, temperature, top-p, top-k, etc.). """ - def __init__(self, vocab_size: int) -> None: + def __init__(self, + vocab_size: int, + org_vocab_size: Optional[int] = None) -> None: super().__init__() self.vocab_size = vocab_size + # original vocabulary size (without LoRA). + self.org_vocab_size = org_vocab_size or vocab_size + + def _get_logits(self, hidden_states: torch.Tensor, embedding: torch.Tensor, + embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + # Get the logits for the next tokens. + logits = torch.matmul(hidden_states, embedding.t()) + if embedding_bias is not None: + logits += embedding_bias + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). + if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits def forward( self, @@ -42,8 +58,7 @@ def forward( hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) # Get the logits for the next tokens. - logits = _get_logits(hidden_states, embedding, embedding_bias, - self.vocab_size) + logits = self._get_logits(hidden_states, embedding, embedding_bias) # Only perform sampling in the driver worker. # Note: `_get_logits` is still distributed across TP workers because @@ -98,20 +113,6 @@ def forward( prompt_logprobs, sample_logprobs) -def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor, - embedding_bias: Optional[torch.Tensor], - vocab_size: int) -> Optional[torch.Tensor]: - # Get the logits for the next tokens. - logits = torch.matmul(hidden_states, embedding.t()) - if embedding_bias is not None: - logits += embedding_bias - logits = tensor_model_parallel_gather(logits) - # Remove paddings in vocab (if any). - if logits is not None: - logits = logits[:, :vocab_size] - return logits - - def _prune_hidden_states( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, diff --git a/vllm/model_executor/layers/vocab_parallel_embedding.py b/vllm/model_executor/layers/vocab_parallel_embedding.py index b08d5555b0faa..9c5fb890251ed 100644 --- a/vllm/model_executor/layers/vocab_parallel_embedding.py +++ b/vllm/model_executor/layers/vocab_parallel_embedding.py @@ -13,8 +13,11 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.utils import set_weight_attrs +DEFAULT_VOCAB_PADDING_SIZE = 64 -def pad_vocab_size(vocab_size: int, pad_to: int = 64) -> int: + +def pad_vocab_size(vocab_size: int, + pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int: """Pad the vocab size to the given value.""" return ((vocab_size + pad_to - 1) // pad_to) * pad_to @@ -43,17 +46,23 @@ class VocabParallelEmbedding(torch.nn.Module): num_embeddings: vocabulary size. embedding_dim: size of hidden state. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. """ def __init__(self, num_embeddings: int, embedding_dim: int, - params_dtype: Optional[torch.dtype] = None): + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): super().__init__() # Keep the input dimensions. self.num_embeddings = num_embeddings - self.num_embeddings_padded = pad_vocab_size(num_embeddings) + self.org_vocab_size = org_num_embeddings or num_embeddings + self.num_embeddings_padded = pad_vocab_size(num_embeddings, + padding_size) self.embedding_dim = embedding_dim if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -77,7 +86,7 @@ def __init__(self, def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): parallel_dim = param.parallel_dim - assert loaded_weight.shape[parallel_dim] == self.num_embeddings + assert loaded_weight.shape[parallel_dim] == self.org_vocab_size loaded_weight = loaded_weight[self.vocab_start_index:self. vocab_end_index] param[:loaded_weight.shape[0]].data.copy_(loaded_weight) @@ -114,14 +123,19 @@ class ParallelLMHead(VocabParallelEmbedding): embedding_dim: size of hidden state. bias: whether to use bias. params_dtype: type of the parameters. + org_num_embeddings: original vocabulary size (without LoRA). + padding_size: padding size for the vocabulary. """ def __init__(self, num_embeddings: int, embedding_dim: int, bias: bool = False, - params_dtype: Optional[torch.dtype] = None): - super().__init__(num_embeddings, embedding_dim, params_dtype) + params_dtype: Optional[torch.dtype] = None, + org_num_embeddings: Optional[int] = None, + padding_size: int = DEFAULT_VOCAB_PADDING_SIZE): + super().__init__(num_embeddings, embedding_dim, params_dtype, + org_num_embeddings, padding_size) if bias: self.bias = Parameter( torch.empty(self.num_embeddings_per_partition, diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 37543d8c9838e..0f1125e5c8e3e 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -1,12 +1,12 @@ """Utilities for selecting and loading models.""" import contextlib -from typing import Type +from typing import Optional, Type import torch import torch.nn as nn from transformers import PretrainedConfig -from vllm.config import ModelConfig +from vllm.config import ModelConfig, LoRAConfig from vllm.model_executor.models import ModelRegistry from vllm.model_executor.weight_utils import (get_quant_config, initialize_dummy_weights) @@ -32,7 +32,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: f"Supported architectures: {ModelRegistry.get_supported_archs()}") -def get_model(model_config: ModelConfig) -> nn.Module: +def get_model(model_config: ModelConfig, + lora_config: Optional[LoRAConfig] = None) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. @@ -62,7 +63,17 @@ def get_model(model_config: ModelConfig) -> nn.Module: # Create a model instance. # The weights will be initialized as empty tensors. with torch.device("cuda"): - model = model_class(model_config.hf_config, linear_method) + if getattr(model_class, "supports_lora", False): + model = model_class(model_config.hf_config, linear_method, + lora_config) + elif lora_config: + raise ValueError( + f"Model {model_class.__name__} does not support LoRA, " + "but LoRA is enabled. Support for this model may " + "be added in the future. If this is important to you, " + "please open an issue on github.") + else: + model = model_class(model_config.hf_config, linear_method) if model_config.load_format == "dummy": # NOTE(woosuk): For accurate performance evaluation, we assign # random values to the weights. diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 3791aa893893a..e5a1abebf1420 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -38,13 +38,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -225,14 +226,19 @@ def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ LlamaDecoderLayer(config, linear_method) @@ -263,18 +269,31 @@ def forward( class LlamaForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: LlamaConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = LlamaModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = LlamaModel(config, linear_method, lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/model_executor/models/mistral.py b/vllm/model_executor/models/mistral.py index 70d033fec69fc..01cde67844122 100644 --- a/vllm/model_executor/models/mistral.py +++ b/vllm/model_executor/models/mistral.py @@ -38,13 +38,14 @@ from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, ParallelLMHead, DEFAULT_VOCAB_PADDING_SIZE) from vllm.model_executor.parallel_utils.parallel_state import ( get_tensor_model_parallel_world_size) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm.config import LoRAConfig KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -220,15 +221,20 @@ def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.padding_idx = config.pad_token_id - self.vocab_size = config.vocab_size + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding( - config.vocab_size, + self.vocab_size, config.hidden_size, + org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ MistralDecoderLayer(config, linear_method) @@ -259,18 +265,33 @@ def forward( class MistralForCausalLM(nn.Module): + supports_lora = True def __init__( self, config: MistralConfig, linear_method: Optional[LinearMethodBase] = None, + lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.linear_method = linear_method - self.model = MistralModel(config, linear_method) - self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) - self.sampler = Sampler(config.vocab_size) + self.model = MistralModel(config, + linear_method, + lora_config=lora_config) + unpadded_vocab_size = config.vocab_size + if lora_config: + unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + ) + self.sampler = Sampler(unpadded_vocab_size, config.vocab_size) def forward( self, diff --git a/vllm/model_executor/parallel_utils/parallel_state.py b/vllm/model_executor/parallel_utils/parallel_state.py index ecc94f0252348..46bff7e16b79f 100644 --- a/vllm/model_executor/parallel_utils/parallel_state.py +++ b/vllm/model_executor/parallel_utils/parallel_state.py @@ -195,10 +195,14 @@ def get_pipeline_model_parallel_prev_rank(): def destroy_model_parallel(): - """Set the groups to none.""" + """Set the groups to none and destroy them.""" global _TENSOR_MODEL_PARALLEL_GROUP + if _TENSOR_MODEL_PARALLEL_GROUP: + torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP) _TENSOR_MODEL_PARALLEL_GROUP = None global _PIPELINE_MODEL_PARALLEL_GROUP + if _PIPELINE_MODEL_PARALLEL_GROUP: + torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP) _PIPELINE_MODEL_PARALLEL_GROUP = None global _PIPELINE_GLOBAL_RANKS _PIPELINE_GLOBAL_RANKS = None diff --git a/vllm/outputs.py b/vllm/outputs.py index fe54926e06e64..534e9d5ea8a53 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -2,6 +2,7 @@ from vllm.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, SequenceStatus) +from vllm.lora.request import LoRARequest class CompletionOutput: @@ -16,6 +17,7 @@ class CompletionOutput: logprobs: The log probabilities of the top probability words at each position if the logprobs are requested. finish_reason: The reason why the sequence is finished. + lora_request: The LoRA request that was used to generate the output. """ def __init__( @@ -26,6 +28,7 @@ def __init__( cumulative_logprob: float, logprobs: Optional[SampleLogprobs], finish_reason: Optional[str] = None, + lora_request: Optional[LoRARequest] = None, ) -> None: self.index = index self.text = text @@ -33,6 +36,7 @@ def __init__( self.cumulative_logprob = cumulative_logprob self.logprobs = logprobs self.finish_reason = finish_reason + self.lora_request = lora_request def finished(self) -> bool: return self.finish_reason is not None @@ -56,6 +60,7 @@ class RequestOutput: prompt_logprobs: The log probabilities to return per prompt token. outputs: The output sequences of the request. finished: Whether the whole request is finished. + lora_request: The LoRA request that was used to generate the output. """ def __init__( @@ -66,6 +71,7 @@ def __init__( prompt_logprobs: Optional[PromptLogprobs], outputs: List[CompletionOutput], finished: bool, + lora_request: Optional[LoRARequest] = None, ) -> None: self.request_id = request_id self.prompt = prompt @@ -73,6 +79,7 @@ def __init__( self.prompt_logprobs = prompt_logprobs self.outputs = outputs self.finished = finished + self.lora_request = lora_request @classmethod def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": @@ -108,8 +115,13 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": prompt_token_ids = seq_group.prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() - return cls(seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished) + return cls(seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + lora_request=seq_group.lora_request) def __repr__(self) -> str: return (f"RequestOutput(request_id={self.request_id}, " @@ -117,4 +129,5 @@ def __repr__(self) -> str: f"prompt_token_ids={self.prompt_token_ids}, " f"prompt_logprobs={self.prompt_logprobs}, " f"outputs={self.outputs}, " - f"finished={self.finished})") + f"finished={self.finished}, " + f"lora_request={self.lora_request})") diff --git a/vllm/prefix.py b/vllm/prefix.py index 06b5b32a38fcc..5b6e8e4b92be6 100644 --- a/vllm/prefix.py +++ b/vllm/prefix.py @@ -74,13 +74,14 @@ def _truncate_token_ids(self, token_ids: Sequence[int]) -> Tuple[int]: new_length = len(token_ids) // self.block_size * self.block_size return tuple(token_ids[:new_length]) - def add_or_get_prefix(self, token_ids: Sequence[int]) -> Optional[Prefix]: + def add_or_get_prefix(self, token_ids: Sequence[int], + lora_int_id: int) -> Optional[Prefix]: token_ids = self._truncate_token_ids(token_ids) if len(token_ids) == 0: # Prefix is empty. return None prefix = Prefix(token_ids, self.block_size) - prefix_hash = hash(prefix) + prefix_hash = hash((prefix, lora_int_id)) if prefix_hash not in self.prefixes: self.prefixes[prefix_hash] = prefix return self.prefixes[prefix_hash] diff --git a/vllm/sequence.py b/vllm/sequence.py index ca647afce9f1e..d28627f47498f 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -6,6 +6,7 @@ from vllm.block import LogicalTokenBlock from vllm.prefix import Prefix from vllm.sampling_params import SamplingParams +from vllm.lora.request import LoRARequest PromptLogprobs = List[Optional[Dict[int, float]]] SampleLogprobs = List[Dict[int, float]] @@ -106,6 +107,7 @@ class Sequence: prompt_token_ids: The token IDs of the prompt. block_size: The block size of the sequence. Should be the same as the block size used by the block manager and cache engine. + lora_request: LoRA request. """ def __init__( @@ -114,10 +116,12 @@ def __init__( prompt: str, prompt_token_ids: List[int], block_size: int, + lora_request: Optional[LoRARequest] = None, ) -> None: self.seq_id = seq_id self.prompt = prompt self.block_size = block_size + self.lora_request = lora_request self.data = SequenceData(prompt_token_ids) self.output_logprobs: SampleLogprobs = [] @@ -134,6 +138,10 @@ def __init__( # Input + output tokens self.tokens: Optional[List[str]] = None + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def _append_logical_block(self) -> None: block = LogicalTokenBlock( block_number=len(self.logical_token_blocks), @@ -229,6 +237,7 @@ class SequenceGroup: seqs: The list of sequences. sampling_params: The sampling parameters used to generate the outputs. arrival_time: The arrival time of the request. + lora_request: LoRA request. prefix: The prefix of the prompt of the sequence group. """ @@ -238,12 +247,14 @@ def __init__( seqs: List[Sequence], sampling_params: SamplingParams, arrival_time: float, + lora_request: Optional[LoRARequest] = None, prefix: Optional[Prefix] = None, ) -> None: self.request_id = request_id self.seqs_dict = {seq.seq_id: seq for seq in seqs} self.sampling_params = sampling_params self.arrival_time = arrival_time + self.lora_request = lora_request self.prefix: Optional[Prefix] = prefix self.prompt_logprobs: Optional[PromptLogprobs] = None @@ -259,6 +270,10 @@ def prompt_token_ids(self) -> List[int]: # We use the prompt of an arbitrary sequence. return next(iter(self.seqs_dict.values())).data.prompt_token_ids + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" @@ -338,6 +353,7 @@ class SequenceGroupMetadata: sampling_params: The sampling parameters used to generate the outputs. block_tables: The block tables. (Seq id -> list of physical block numbers) + lora_request: LoRA request. prefix: The prefix of the prompt of the sequence group. """ @@ -348,6 +364,7 @@ def __init__( seq_data: Dict[int, SequenceData], sampling_params: SamplingParams, block_tables: Dict[int, List[int]], + lora_request: Optional[LoRARequest] = None, prefix: Optional[Prefix] = None, ) -> None: self.request_id = request_id @@ -355,8 +372,13 @@ def __init__( self.seq_data = seq_data self.sampling_params = sampling_params self.block_tables = block_tables + self.lora_request = lora_request self.prefix = prefix + @property + def lora_int_id(self) -> int: + return self.lora_request.lora_int_id if self.lora_request else 0 + class SequenceOutput: """The model output associated with a sequence. diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index a67d2f83a2549..6edc225cdfc80 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -4,6 +4,8 @@ PreTrainedTokenizerFast) from vllm.logger import init_logger +from vllm.lora.request import LoRARequest +from vllm.utils import make_async, LRUCache from vllm.transformers_utils.tokenizers import * logger = init_logger(__name__) @@ -65,6 +67,84 @@ def get_tokenizer( return tokenizer +def get_lora_tokenizer(lora_request: LoRARequest, *args, + **kwargs) -> Optional[PreTrainedTokenizer]: + if lora_request is None: + return None + try: + tokenizer = get_tokenizer(lora_request.lora_local_path, *args, + **kwargs) + except OSError as e: + # No tokenizer was found in the LoRA folder, + # use base model tokenizer + logger.warning( + f"No tokenizer found in {lora_request.lora_local_path}, " + "using base model tokenizer instead. " + f"(Exception: {str(e)})") + tokenizer = None + return tokenizer + + +get_lora_tokenizer_async = make_async(get_lora_tokenizer) + + +class TokenizerGroup: + """A group of tokenizers that can be used for LoRA adapters.""" + + def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int, + max_input_length: Optional[int], **tokenizer_config): + self.tokenizer_id = tokenizer_id + self.tokenizer_config = tokenizer_config + self.enable_lora = enable_lora + self.max_input_length = max_input_length + self.tokenizer = get_tokenizer(self.tokenizer_id, **tokenizer_config) + if enable_lora: + self.lora_tokenizers = LRUCache(capacity=max_num_seqs) + else: + self.lora_tokenizers = None + + def encode(self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = self.get_lora_tokenizer(lora_request) + return tokenizer.encode(prompt) + + async def encode_async( + self, + prompt: str, + request_id: Optional[str] = None, + lora_request: Optional[LoRARequest] = None) -> List[int]: + tokenizer = await self.get_lora_tokenizer_async(lora_request) + return tokenizer.encode(prompt) + + def get_lora_tokenizer( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (get_lora_tokenizer( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + async def get_lora_tokenizer_async( + self, + lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer": + if not lora_request or not self.enable_lora: + return self.tokenizer + if lora_request.lora_int_id not in self.lora_tokenizers: + tokenizer = (await get_lora_tokenizer_async( + lora_request, **self.tokenizer_config) or self.tokenizer) + self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer) + return tokenizer + else: + return self.lora_tokenizers.get(lora_request.lora_int_id) + + def _convert_tokens_to_string_with_added_encoders( tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], output_tokens: List[str], diff --git a/vllm/utils.py b/vllm/utils.py index c83b05ff609c6..23b6ca320d300 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -7,6 +7,17 @@ import psutil import torch +import asyncio +from functools import partial +from typing import ( + Awaitable, + Callable, + TypeVar, +) +from collections import OrderedDict +from typing import Any, Hashable, Optional + +T = TypeVar("T") class Device(enum.Enum): @@ -28,6 +39,69 @@ def reset(self) -> None: self.counter = 0 +class LRUCache: + + def __init__(self, capacity: int): + self.cache = OrderedDict() + self.capacity = capacity + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def __getitem__(self, key: Hashable) -> Any: + return self.get(key) + + def __setitem__(self, key: Hashable, value: Any) -> None: + self.put(key, value) + + def __delitem__(self, key: Hashable) -> None: + self.pop(key) + + def touch(self, key: Hashable) -> None: + self.cache.move_to_end(key) + + def get(self, key: Hashable, default_value: Optional[Any] = None) -> int: + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + else: + value = default_value + return value + + def put(self, key: Hashable, value: Any) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + self._remove_old_if_needed() + + def _on_remove(self, key: Hashable, value: Any): + pass + + def remove_oldest(self): + if not self.cache: + return + key, value = self.cache.popitem(last=False) + self._on_remove(key, value) + + def _remove_old_if_needed(self) -> None: + while len(self.cache) > self.capacity: + self.remove_oldest() + + def pop(self, key: int, default_value: Optional[Any] = None) -> Any: + run_on_remove = key in self.cache + value = self.cache.pop(key, default_value) + if run_on_remove: + self._on_remove(key, value) + return value + + def clear(self): + while len(self.cache) > 0: + self.remove_oldest() + self.cache.clear() + + def is_hip() -> bool: return torch.version.hip is not None @@ -59,6 +133,22 @@ def in_wsl() -> bool: return "microsoft" in " ".join(uname()).lower() +def make_async(func: Callable[..., T]) -> Callable[..., Awaitable[T]]: + """Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args, **kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=None, func=p_func) + + return _async_wrapper + + def get_ip() -> str: s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index aa37facb0ff90..985115613e044 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -1,23 +1,27 @@ import time -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Optional, Tuple, Set, Union import numpy as np import torch import torch.nn as nn -from vllm.config import ModelConfig, ParallelConfig, SchedulerConfig +from vllm.config import ModelConfig, LoRAConfig, ParallelConfig, SchedulerConfig from vllm.logger import init_logger from vllm.model_executor import get_model, InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) from vllm.sampling_params import SamplingParams, SamplingType from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata +from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager +from vllm.lora.layers import LoRAMapping +from vllm.lora.request import LoRARequest from vllm.utils import in_wsl logger = init_logger(__name__) KVCache = Tuple[torch.Tensor, torch.Tensor] _PAD_SLOT_ID = -1 +LORA_WARMUP_RANK = 8 # Capture graphs for batch size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256. # NOTE: _get_graph_batch_size needs to be updated if this list is changed. _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)] @@ -30,19 +34,23 @@ def __init__( model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, + lora_config: Optional[LoRAConfig], is_driver_worker: bool = False, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config + self.lora_config = lora_config self.is_driver_worker = is_driver_worker # model_config can be None in tests/samplers/test_sampler.py. # FIXME(woosuk): This is a hack to make the tests work. Refactor this. self.sliding_window = (model_config.get_sliding_window() if model_config is not None else None) + self.device = torch.device(torch.cuda.current_device()) self.model = None self.block_size = None # Set after initial profiling. + self.lora_manager = None self.graph_runners: Dict[int, CUDAGraphRunner] = {} self.graph_memory_pool = None # Set during graph capture. @@ -61,7 +69,17 @@ def __init__( self.in_wsl = in_wsl() def load_model(self) -> None: - self.model = get_model(self.model_config) + self.model = get_model(self.model_config, self.lora_config) + + vocab_size = self.model.config.vocab_size + + if self.lora_config: + self.lora_manager = LRUCacheWorkerLoRAManager( + self.scheduler_config.max_num_seqs, + self.scheduler_config.max_num_batched_tokens + + self.scheduler_config.max_paddings, vocab_size, + self.lora_config, self.device) + self.model = self.lora_manager.create_lora_manager(self.model) def set_block_size(self, block_size: int) -> None: self.block_size = block_size @@ -74,12 +92,15 @@ def set_block_size(self, block_size: int) -> None: def _prepare_prompt( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], - List[int]]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + List[int], List[int], Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() prompt_lens: List[int] = [] context_lens: List[int] = [] @@ -113,6 +134,17 @@ def _prepare_prompt( input_positions.append( list(range(prefix_len, prefix_len + len(prompt_tokens)))) + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + + lora_index_mapping.append([lora_id] * prompt_len) + lora_prompt_mapping.extend( + [lora_id] * + (prompt_len + if seq_group_metadata.sampling_params.prompt_logprobs else 1)) + if seq_group_metadata.block_tables is None: # During memory profiling, the block tables are not initialized # yet. In this case, we just use a dummy slot mapping. @@ -156,6 +188,10 @@ def _prepare_prompt( max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long) + lora_index_mapping = [ + _pad_to_max(mapping, max_prompt_len, pad=0) + for mapping in lora_index_mapping + ] context_lens_tensor = torch.tensor(context_lens, dtype=torch.int, device='cuda') @@ -188,23 +224,33 @@ def _prepare_prompt( use_cuda_graph=False, ) return (input_tokens, input_positions, input_metadata, prompt_lens, - subquery_lens) + subquery_lens, lora_index_mapping, lora_prompt_mapping, + lora_requests) def _prepare_decode( self, seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, List[int], List[int], + Set[LoRARequest]]: assert len(seq_group_metadata_list) > 0 input_tokens: List[List[int]] = [] input_positions: List[List[int]] = [] slot_mapping: List[List[int]] = [] context_lens: List[int] = [] block_tables: List[List[int]] = [] + lora_index_mapping: List[int] = [] + lora_prompt_mapping: List[int] = [] + lora_requests: Set[LoRARequest] = set() for seq_group_metadata in seq_group_metadata_list: assert not seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) + lora_id = seq_group_metadata.lora_int_id + + if lora_id > 0: + lora_requests.add(seq_group_metadata.lora_request) + for seq_id in seq_ids: seq_data = seq_group_metadata.seq_data[seq_id] generation_token = seq_data.get_last_token_id() @@ -223,6 +269,8 @@ def _prepare_decode( block_offset = position % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append([slot]) + lora_index_mapping.append([lora_id]) + lora_prompt_mapping.append(lora_id) if self.sliding_window is not None: sliding_window_blocks = (self.sliding_window // @@ -287,6 +335,10 @@ def _prepare_decode( device="cuda", ) + lora_index_mapping = [ + _pad_to_max(mapping, 1, pad=0) for mapping in lora_index_mapping + ] + input_metadata = InputMetadata( is_prompt=False, slot_mapping=slot_mapping, @@ -298,7 +350,7 @@ def _prepare_decode( block_tables=block_tables, use_cuda_graph=use_captured_graph, ) - return input_tokens, input_positions, input_metadata + return input_tokens, input_positions, input_metadata, lora_index_mapping, lora_prompt_mapping, lora_requests def _prepare_sample( self, @@ -375,7 +427,8 @@ def _prepare_sample( def prepare_input_tensors( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata]: + ) -> Tuple[torch.Tensor, torch.Tensor, InputMetadata, SamplingMetadata, + Set[int], LoRAMapping]: if self.is_driver_worker: # NOTE: We assume that all sequences in the group are all prompts or # all decodes. @@ -383,16 +436,29 @@ def prepare_input_tensors( # Prepare input tensors. if is_prompt: (input_tokens, input_positions, input_metadata, prompt_lens, - subquery_lens) = self._prepare_prompt(seq_group_metadata_list) + subquery_lens, lora_index_mapping, lora_prompt_mapping, + lora_requests) = self._prepare_prompt(seq_group_metadata_list) else: - (input_tokens, input_positions, input_metadata - ) = self._prepare_decode(seq_group_metadata_list) - subquery_lens = None + (input_tokens, input_positions, input_metadata, + lora_index_mapping, lora_prompt_mapping, + lora_requests) = self._prepare_decode(seq_group_metadata_list) prompt_lens = [] + subquery_lens = None sampling_metadata = self._prepare_sample(seq_group_metadata_list, prompt_lens, subquery_lens) + if self.lora_config: + flat_lora_index_mapping = [ + item for sublist in lora_index_mapping for item in sublist + ] + lora_mapping = LoRAMapping( + flat_lora_index_mapping, + lora_prompt_mapping, + ) + else: + lora_mapping = None + # Broadcast the metadata. metadata_dict = { "input_tokens": input_tokens, @@ -408,12 +474,16 @@ def prepare_input_tensors( "use_cuda_graph": input_metadata.use_cuda_graph, "selected_token_indices": sampling_metadata.selected_token_indices, + "lora_requests": lora_requests, + "lora_mapping": lora_mapping, } broadcast_tensor_dict(metadata_dict, src=0) else: metadata_dict = broadcast_tensor_dict(src=0) input_tokens = metadata_dict["input_tokens"] input_positions = metadata_dict["input_positions"] + lora_mapping = metadata_dict["lora_mapping"] + lora_requests = metadata_dict["lora_requests"] input_metadata = InputMetadata( is_prompt=metadata_dict["is_prompt"], slot_mapping=metadata_dict["slot_mapping"], @@ -434,7 +504,7 @@ def prepare_input_tensors( perform_sampling=False, ) - return input_tokens, input_positions, input_metadata, sampling_metadata + return input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping @torch.inference_mode() def execute_model( @@ -442,8 +512,12 @@ def execute_model( seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: - input_tokens, input_positions, input_metadata, sampling_metadata = ( + input_tokens, input_positions, input_metadata, sampling_metadata, lora_requests, lora_mapping = ( self.prepare_input_tensors(seq_group_metadata_list)) + + if self.lora_config: + self.set_active_loras(lora_requests, lora_mapping) + # Execute the model. if input_metadata.use_cuda_graph: graph_batch_size = input_tokens.shape[0] @@ -472,6 +546,28 @@ def profile_run(self) -> None: max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs + # This represents the maximum number of different requests + # that will have unique loras, an therefore the max amount of memory + # consumption create dummy lora request copies from the lora request + # passed in, which contains a lora from the lora warmup path. + dummy_lora_requests = [] + dummy_lora_requests_per_seq = [] + if self.lora_config: + for idx in range(self.lora_config.max_loras): + lora_id = idx + 1 + dummy_lora_request = LoRARequest( + lora_name=f"warmup_{lora_id}", + lora_int_id=lora_id, + lora_local_path="/not/a/real/path", + ) + self.lora_manager.add_dummy_lora(dummy_lora_request, + rank=LORA_WARMUP_RANK) + dummy_lora_requests.append(dummy_lora_request) + dummy_lora_requests_per_seq = [ + dummy_lora_requests[idx % len(dummy_lora_requests)] + for idx in range(max_num_seqs) + ] + # Profile memory usage with max_num_sequences sequences and the total # number of tokens equal to max_num_batched_tokens. seqs: List[SequenceGroupMetadata] = [] @@ -485,6 +581,8 @@ def profile_run(self) -> None: seq_data={group_id: seq_data}, sampling_params=sampling_params, block_tables=None, + lora_request=dummy_lora_requests_per_seq[group_id] + if dummy_lora_requests_per_seq else None, ) seqs.append(seq) @@ -495,6 +593,32 @@ def profile_run(self) -> None: torch.cuda.synchronize() return + def remove_all_loras(self) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_all_loras() + + def set_active_loras(self, lora_requests: List[LoRARequest], + lora_mapping: LoRAMapping) -> None: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + self.lora_manager.set_active_loras(lora_requests, lora_mapping) + + def add_lora(self, lora_request: LoRARequest) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + if not self.lora_manager: + raise RuntimeError("LoRA is not enabled.") + return self.lora_manager.list_loras() + @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: assert not self.model_config.enforce_eager @@ -541,6 +665,13 @@ def capture_model(self, kv_caches: List[KVCache]) -> None: use_cuda_graph=True, ) + if self.lora_config: + lora_mapping = LoRAMapping( + [0] * batch_size, + [0] * batch_size, + ) + self.set_active_loras(set(), lora_mapping) + graph_runner = CUDAGraphRunner(self.model) graph_runner.capture( input_tokens[:batch_size], diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 7d99c634ded1b..845283586e147 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,12 +1,13 @@ """A GPU worker class.""" +import gc import os -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple, Set, Optional import torch import torch.distributed from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) + SchedulerConfig, LoRAConfig) from vllm.model_executor import set_random_seed from vllm.model_executor.parallel_utils.communication_op import ( broadcast_tensor_dict) @@ -15,6 +16,7 @@ from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from vllm.worker.model_runner import ModelRunner +from vllm.lora.request import LoRARequest class Worker: @@ -33,6 +35,7 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, + lora_config: Optional[LoRAConfig] = None, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -41,12 +44,16 @@ def __init__( self.local_rank = local_rank self.rank = rank self.distributed_init_method = distributed_init_method + self.lora_config = lora_config self.is_driver_worker = is_driver_worker if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - self.model_runner = ModelRunner(model_config, parallel_config, - scheduler_config, is_driver_worker) + self.model_runner = ModelRunner(model_config, + parallel_config, + scheduler_config, + lora_config=self.lora_config, + is_driver_worker=is_driver_worker) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). self.cache_config = None @@ -117,6 +124,9 @@ def profile_num_available_blocks( num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + if self.model_runner.lora_manager: + self.model_runner.remove_all_loras() + gc.collect() torch.cuda.empty_cache() return num_gpu_blocks, num_cpu_blocks @@ -199,6 +209,15 @@ def execute_model( self.gpu_cache) return output + def add_lora(self, lora_request: LoRARequest) -> bool: + return self.model_runner.add_lora(lora_request) + + def remove_lora(self, lora_id: int) -> bool: + return self.model_runner.remove_lora(lora_id) + + def list_loras(self) -> Set[int]: + return self.model_runner.list_loras() + def _init_distributed_environment( parallel_config: ParallelConfig,