From bd5b60a6a0f2bcf58f09720af276f3021d9c75fb Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 10 Mar 2024 21:24:05 -0700 Subject: [PATCH] ci: reduce binary size (#172) 1. Do not generate prefill kernels for `page_size=8` 2. Build with `-Xfatbin=-compress-all` to reduce binary size. Followup of #171 , @Qubitium the cuda architectures to be compiled could be controlled by environment variable `TORCH_CUDA_ARCH_LIST`, so I removed the gencode/archs specified in compile args. --- .github/workflows/release_wheel.yml | 2 +- python/generate_batch_paged_prefill_inst.py | 2 +- python/setup.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/.github/workflows/release_wheel.yml b/.github/workflows/release_wheel.yml index c09f0482..ce5989ad 100644 --- a/.github/workflows/release_wheel.yml +++ b/.github/workflows/release_wheel.yml @@ -18,7 +18,7 @@ on: # required: true env: - TORCH_CUDA_ARCH_LIST: "8.0 8.6 8.9 9.0+PTX" + TORCH_CUDA_ARCH_LIST: "8.0 8.9 9.0+PTX" jobs: build: diff --git a/python/generate_batch_paged_prefill_inst.py b/python/generate_batch_paged_prefill_inst.py index ab6e4500..57227a51 100644 --- a/python/generate_batch_paged_prefill_inst.py +++ b/python/generate_batch_paged_prefill_inst.py @@ -36,9 +36,9 @@ def get_cu_file_str( dtype_in, dtype_out, idtype, + page_size_choices=[1, 8, 16, 32], ): num_frags_x_choices = [1, 2] - page_size_choices = [1, 8, 16, 32] insts = "\n".join( [ """template cudaError_t BatchPrefillWithPagedKVCacheDispatched( diff --git a/python/setup.py b/python/setup.py index 03653e5f..bb3ef10e 100644 --- a/python/setup.py +++ b/python/setup.py @@ -265,6 +265,7 @@ def get_instantiation_cu() -> List[str]: dtype, dtype, idtype, + page_size_choices=[1, 16, 32], ) write_if_different(root / prefix / fname, content) @@ -378,9 +379,8 @@ def __init__(self, *args, **kwargs) -> None: str(root.resolve() / "include"), ], extra_compile_args={ - "cxx": ["-O3", "-std=c++17"], - "nvcc": ["-O3", "-std=c++17", "--threads", "8", "-gencode", "arch=compute_80,code=sm_80", - "-gencode", "arch=compute_89,code=sm_89", "-gencode", "arch=compute_90,code=sm_90"], + "cxx": ["-O3"], + "nvcc": ["-O3", "-std=c++17", "--threads", "8", "-Xfatbin", "-compress-all"], }, ) )