From 3d6520e2ebd8b0d4ce82f140226eac0f4563a0cd Mon Sep 17 00:00:00 2001 From: Sanghyuk Choi Date: Thu, 25 Apr 2024 18:43:42 +0900 Subject: [PATCH 01/10] DOC DeepSpeed and QLoRA compatibility (#1679) --- docs/source/accelerate/deepspeed.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/source/accelerate/deepspeed.md b/docs/source/accelerate/deepspeed.md index d3d6aeb714..8442306e00 100644 --- a/docs/source/accelerate/deepspeed.md +++ b/docs/source/accelerate/deepspeed.md @@ -22,8 +22,6 @@ For DeepSpeed Stage 3 + QLoRA, please refer to the section [Use PEFT QLoRA and D For confirming these observations, we ran the SFT (Supervised Fine-tuning) [offical example scripts](https://github.com/huggingface/trl/tree/main/examples) of the [Transformers Reinforcement Learning (TRL) library](https://github.com/huggingface/trl) using QLoRA + PEFT and the accelerate configs available [here](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs). We ran these experiments on a 2x NVIDIA T4 GPU. -Note DeepSpeed-Zero3 and `bitsandbytes` are currently **not** compatible. - # Use PEFT and DeepSpeed with ZeRO3 for finetuning large models on multiple devices and multiple nodes This section of guide will help you learn how to use our DeepSpeed [training script](https://github.com/huggingface/peft/blob/main/examples/sft/train.py) for performing SFT. You'll configure the script to do SFT (supervised fine-tuning) of Llama-70B model with LoRA and ZeRO-3 on 8xH100 80GB GPUs on a single machine. You can configure it to scale to multiple machines by changing the accelerate config. From 835181460c09e40a1a237251f633e168b8a0af45 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 25 Apr 2024 13:08:33 +0200 Subject: [PATCH 02/10] ENH: Add multi-backend tests for bnb (#1667) * add multi-backend tests for bnb * Create README.md * Update build_docker_images.yml --- .github/workflows/build_docker_images.yml | 56 +++++++++++++++++ .github/workflows/nightly-bnb.yml | 4 +- docker/README.md | 11 ++++ docker/peft-gpu-bnb-multi-source/Dockerfile | 68 +++++++++++++++++++++ 4 files changed, 137 insertions(+), 2 deletions(-) create mode 100644 docker/README.md create mode 100644 docker/peft-gpu-bnb-multi-source/Dockerfile diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml index dcf43ac15a..64f9448109 100644 --- a/.github/workflows/build_docker_images.yml +++ b/.github/workflows/build_docker_images.yml @@ -239,3 +239,59 @@ jobs: } env: SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} + + latest-cuda-bnb-source-multi: + name: "Latest Peft GPU + bnb (multi-backend) source [accelerate / peft / transformers source]" + runs-on: ubuntu-latest + steps: + - name: Cleanup disk + run: | + sudo ls -l /usr/local/lib/ + sudo ls -l /usr/share/ + sudo du -sh /usr/local/lib/ + sudo du -sh /usr/share/ + sudo rm -rf /usr/local/lib/android + sudo rm -rf /usr/share/dotnet + sudo du -sh /usr/local/lib/ + sudo du -sh /usr/share/ + - name: Set up Docker Buildx + uses: docker/setup-buildx-action@v1 + - name: Check out code + uses: actions/checkout@v3 + - name: Login to DockerHub + uses: docker/login-action@v1 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_PASSWORD }} + + - name: Build and Push GPU + uses: docker/build-push-action@v4 + with: + context: ./docker/peft-gpu-bnb-multi-source + push: true + tags: huggingface/peft-gpu-bnb-multi-source + + - name: Post to a Slack channel + id: slack + #uses: slackapi/slack-github-action@v1.25.0 + uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 + with: + # Slack channel id, channel name, or user id to post message. + # See also: https://api.slack.com/methods/chat.postMessage#channels + channel-id: ${{ env.CI_SLACK_CHANNEL }} + # For posting a rich message using Block Kit + payload: | + { + "text": "peft-gpu + bnb-source (latest) Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}", + "blocks": [ + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "peft-gpu + bnb-source (latest) Docker Image build result: ${{ job.status }}\n${{ github.event.pull_request.html_url || github.event.head_commit.url }}" + } + } + ] + } + env: + SLACK_BOT_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} diff --git a/.github/workflows/nightly-bnb.yml b/.github/workflows/nightly-bnb.yml index 0c800357ee..6a30fb71f6 100644 --- a/.github/workflows/nightly-bnb.yml +++ b/.github/workflows/nightly-bnb.yml @@ -18,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - docker-image-name: ["huggingface/peft-gpu-bnb-source:latest", "huggingface/peft-gpu-bnb-latest:latest"] + docker-image-name: ["huggingface/peft-gpu-bnb-source:latest", "huggingface/peft-gpu-bnb-latest:latest", "huggingface/peft-gpu-bnb-multi-source:latest"] runs-on: [self-hosted, single-gpu, nvidia-gpu, t4, ci] env: CUDA_VISIBLE_DEVICES: "0" @@ -74,7 +74,7 @@ jobs: strategy: fail-fast: false matrix: - docker-image-name: ["huggingface/peft-gpu-bnb-source:latest", "huggingface/peft-gpu-bnb-latest:latest"] + docker-image-name: ["huggingface/peft-gpu-bnb-source:latest", "huggingface/peft-gpu-bnb-latest:latest", "huggingface/peft-gpu-bnb-multi-source:latest"] runs-on: [self-hosted, multi-gpu, nvidia-gpu, t4, ci] env: CUDA_VISIBLE_DEVICES: "0,1" diff --git a/docker/README.md b/docker/README.md new file mode 100644 index 0000000000..51d9656a6f --- /dev/null +++ b/docker/README.md @@ -0,0 +1,11 @@ +# PEFT Docker images + +Here we store all PEFT Docker images used in our testing infrastructure. We use python 3.8 for now on all our images. + +- `peft-cpu`: PEFT compiled on CPU with all other HF libraries installed on main branch +- `peft-gpu`: PEFT complied for NVIDIA GPUs wih all other HF libraries installed on main branch +- `peft-gpu-bnb-source`: PEFT complied for NVIDIA GPUs with `bitsandbytes` and all other HF libraries installed from main branch +- `peft-gpu-bnb-latest`: PEFT complied for NVIDIA GPUs with `bitsandbytes` complied from main and all other HF libraries installed from latest PyPi +- `peft-gpu-bnb-multi-source`: PEFT complied for NVIDIA GPUs with `bitsandbytes` complied from `multi-backend` branch and all other HF libraries installed from main branch + +`peft-gpu-bnb-source` and `peft-gpu-bnb-multi-source` are essentially the same, with the only difference being `bitsandbytes` compiled on another branch. Make sure to propagate the changes you applied on one file to the other! diff --git a/docker/peft-gpu-bnb-multi-source/Dockerfile b/docker/peft-gpu-bnb-multi-source/Dockerfile new file mode 100644 index 0000000000..5c8724542b --- /dev/null +++ b/docker/peft-gpu-bnb-multi-source/Dockerfile @@ -0,0 +1,68 @@ +# Builds GPU docker image of PyTorch +# Uses multi-staged approach to reduce size +# Stage 1 +# Use base conda image to reduce time +FROM continuumio/miniconda3:latest AS compile-image +# Specify py version +ENV PYTHON_VERSION=3.8 +# Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +RUN apt-get update && \ + apt-get install -y curl git wget software-properties-common git-lfs && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists* + +# Install audio-related libraries +RUN apt-get update && \ + apt install -y ffmpeg + +RUN apt install -y libsndfile1-dev +RUN git lfs install + +# Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +RUN conda create --name peft python=${PYTHON_VERSION} ipython jupyter pip +RUN python3 -m pip install --no-cache-dir --upgrade pip + +# Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile +# We don't install pytorch here yet since CUDA isn't available +# instead we use the direct torch wheel +ENV PATH /opt/conda/envs/peft/bin:$PATH +# Activate our bash shell +RUN chsh -s /bin/bash +SHELL ["/bin/bash", "-c"] + +# Stage 2 +FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS build-image +COPY --from=compile-image /opt/conda /opt/conda +ENV PATH /opt/conda/bin:$PATH + +RUN chsh -s /bin/bash +SHELL ["/bin/bash", "-c"] + +# Install apt libs +RUN apt-get update && \ + apt-get install -y curl git wget cmake && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists* + +# Activate the conda env and install transformers + accelerate from source +# Also clone BNB and build it from source. +RUN source activate peft && \ + python3 -m pip install -U --no-cache-dir \ + librosa \ + "soundfile>=0.12.1" \ + scipy \ + git+https://github.com/huggingface/transformers \ + git+https://github.com/huggingface/accelerate \ + peft[test]@git+https://github.com/huggingface/peft \ + optimum \ + auto-gptq && \ + git clone https://github.com/TimDettmers/bitsandbytes && cd bitsandbytes && git checkout multi-backend-refactor && \ + cmake -B . -DCOMPUTE_BACKEND=cuda -S . && \ + cmake --build . && \ + pip install -e . && \ + pip freeze | grep bitsandbytes + +RUN echo "source activate peft" >> ~/.profile + +# Activate the virtualenv +CMD ["/bin/bash"] From 3d9529d190a02415bb5f6eba0a6a6cc00696cea5 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Thu, 25 Apr 2024 14:28:03 +0200 Subject: [PATCH 03/10] FIX / Workflow: Fix Mac-OS CI issues (#1680) * Update helpers.py * Update tests.yml * Update src/peft/helpers.py --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index abbc188c9c..01735fef68 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -33,7 +33,7 @@ jobs: strategy: matrix: python-version: ["3.8", "3.9", "3.10", "3.11"] - os: ["ubuntu-latest", "macos-latest", "windows-latest"] + os: ["ubuntu-latest", "macos-12", "windows-latest"] runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v3 From f0d3c6b8923cf3e64032de1211420f135031094d Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 25 Apr 2024 15:15:57 +0200 Subject: [PATCH 04/10] FIX Use trl version of tiny random llama (#1681) Using the version from HuggingFaceM4 broke our tests because it was updated. Although the update is reverted, we still better switch to this version, which is explicitly for testing and should be stable. --- tests/test_common_gpu.py | 6 +++--- tests/test_decoder_models.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_common_gpu.py b/tests/test_common_gpu.py index 29a7b2a123..7e67d44bfd 100644 --- a/tests/test_common_gpu.py +++ b/tests/test_common_gpu.py @@ -469,7 +469,7 @@ def test_lora_seq2seq_lm_multi_gpu_inference(self): @require_bitsandbytes def test_adaption_prompt_8bit(self): model = LlamaForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", + "trl-internal-testing/tiny-random-LlamaForCausalLM", quantization_config=BitsAndBytesConfig(load_in_8bit=True), torch_dtype=torch.float16, device_map="auto", @@ -492,7 +492,7 @@ def test_adaption_prompt_8bit(self): @require_bitsandbytes def test_adaption_prompt_4bit(self): model = LlamaForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", + "trl-internal-testing/tiny-random-LlamaForCausalLM", quantization_config=BitsAndBytesConfig(load_in_4bit=True), torch_dtype=torch.float16, device_map="auto", @@ -982,7 +982,7 @@ def test_4bit_dora_merging(self): bnb_4bit_compute_dtype=torch.float32, ) model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", + "trl-internal-testing/tiny-random-LlamaForCausalLM", quantization_config=bnb_config, torch_dtype=torch.float32, ).eval() diff --git a/tests/test_decoder_models.py b/tests/test_decoder_models.py index 00e0462537..ead2c7bce1 100644 --- a/tests/test_decoder_models.py +++ b/tests/test_decoder_models.py @@ -32,7 +32,7 @@ "hf-internal-testing/tiny-random-gpt_neo", "hf-internal-testing/tiny-random-GPTJForCausalLM", "hf-internal-testing/tiny-random-GPTBigCodeForCausalLM", - "HuggingFaceM4/tiny-random-LlamaForCausalLM", + "trl-internal-testing/tiny-random-LlamaForCausalLM", ] FULL_GRID = { @@ -340,7 +340,7 @@ def test_passing_input_embeds_works(self, test_name, model_id, config_cls, confi self._test_passing_input_embeds_works(test_name, model_id, config_cls, config_kwargs) def test_lora_layer_replication(self): - model_id = "HuggingFaceM4/tiny-random-LlamaForCausalLM" + model_id = "trl-internal-testing/tiny-random-LlamaForCausalLM" config_kwargs = { "target_modules": ["down_proj", "up_proj"], "task_type": "CAUSAL_LM", From b1d6c77108a39358bffadef0481a5f38a3b6c0c3 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Thu, 25 Apr 2024 20:35:16 +0200 Subject: [PATCH 05/10] FIX Don't eagerly import bnb for LoftQ (#1683) We accidentally added code in loftq_utils.py that eagerly imports bnb, which we want to avoid to prevent CUDA from being initialized too early. --- src/peft/utils/loftq_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/peft/utils/loftq_utils.py b/src/peft/utils/loftq_utils.py index 20bbe20ada..f8323485a7 100644 --- a/src/peft/utils/loftq_utils.py +++ b/src/peft/utils/loftq_utils.py @@ -31,10 +31,6 @@ from peft.import_utils import is_bnb_4bit_available, is_bnb_available -if is_bnb_available(): - import bitsandbytes as bnb - - class NFQuantizer: def __init__(self, num_bits=2, device="cuda", method="normal", block_size=64, *args, **kwargs): super().__init__(*args, **kwargs) @@ -192,6 +188,11 @@ def _low_rank_decomposition(weight, reduced_rank=32): @torch.no_grad() def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, reduced_rank: int, num_iter=1): + if is_bnb_available(): + import bitsandbytes as bnb + else: + raise ValueError("bitsandbytes is not available, please install it to use LoftQ.") + if num_bits not in [2, 4, 8]: raise ValueError("Only support 2, 4, 8 bits quantization") if num_iter <= 0: @@ -239,6 +240,8 @@ def loftq_init(weight: Union[torch.Tensor, torch.nn.Parameter], num_bits: int, r @torch.no_grad() def _loftq_init_new(qweight, weight, num_bits: int, reduced_rank: int): + import bitsandbytes as bnb + if num_bits != 4: raise ValueError("Only 4 bit quantization supported at the moment.") if not is_bnb_4bit_available(): From d0fa70aeb6f1cf0fbe11fb8e2e6fd568d4fbdaff Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 26 Apr 2024 10:20:18 +0200 Subject: [PATCH 06/10] FEAT: Add EETQ support in PEFT (#1675) * v1 * fix tests' * fix unneeded change * fix unneeded change * fix unneeded change * fix * fix CI * fix docker image * fix docker image * add docs * lazy import * raise when merge * raise when merge * Update eetq.py * merge * style * add unmerge * indent * Update docs/source/developer_guides/quantization.md Co-authored-by: Benjamin Bossan * add details about transformers --------- Co-authored-by: Benjamin Bossan --- docker/peft-gpu/Dockerfile | 4 + docs/source/developer_guides/quantization.md | 36 +++++ src/peft/import_utils.py | 5 + src/peft/tuners/lora/__init__.py | 7 +- src/peft/tuners/lora/eetq.py | 104 +++++++++++++ src/peft/tuners/lora/layer.py | 3 + src/peft/tuners/lora/model.py | 5 +- src/peft/utils/other.py | 6 +- tests/test_gpu_examples.py | 148 +++++++++++++++++++ tests/testing_utils.py | 15 +- 10 files changed, 328 insertions(+), 5 deletions(-) create mode 100644 src/peft/tuners/lora/eetq.py diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile index c82fc14239..37e5b76fc9 100644 --- a/docker/peft-gpu/Dockerfile +++ b/docker/peft-gpu/Dockerfile @@ -52,6 +52,10 @@ RUN apt-get update && \ apt-get clean && \ rm -rf /var/lib/apt/lists* +# Add eetq for quantization testing +RUN source activate peft && \ + python3 -m pip install git+https://github.com/NetEase-FuXi/EETQ.git + # Activate the conda env and install transformers + accelerate from source RUN source activate peft && \ python3 -m pip install -U --no-cache-dir \ diff --git a/docs/source/developer_guides/quantization.md b/docs/source/developer_guides/quantization.md index 93a6aaacbc..702dee963a 100644 --- a/docs/source/developer_guides/quantization.md +++ b/docs/source/developer_guides/quantization.md @@ -128,6 +128,42 @@ quantized_model = get_peft_model(quantized_model, peft_config) You can refer to the [Google Colab](https://colab.research.google.com/drive/12GTp1FCj5_0SnnNQH18h_2XFh9vS_guX?usp=sharing) example for an overview of AQLM+LoRA finetuning. +## EETQ quantization + +You can also perform LoRA fine-tuning on EETQ quantized models. [EETQ](https://github.com/NetEase-FuXi/EETQ) package offers simple and efficient way to perform 8-bit quantization, which is claimed to be faster than the `LLM.int8()` algorithm. First, make sure that you have a transformers version that is compatible with EETQ (e.g. by installing it from latest pypi or from source). + +```py +import torch +from transformers import EetqConfig + +config = EetqConfig("int8") +``` + +Pass the `config` to the [`~transformers.AutoModelForCausalLM.from_pretrained`] method. + +```py +from transformers import AutoModelForCausalLM + +model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-v0.1", quantization_config=config) +``` + +and create a `LoraConfig` and pass it to `get_peft_model`: + +```py +from peft import LoraConfig, get_peft_model + +config = LoraConfig( + r=16, + lora_alpha=8, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM" +) + +model = get_peft_model(model, config) +``` + ## Next steps If you're interested in learning more about quantization, the following may be helpful: diff --git a/src/peft/import_utils.py b/src/peft/import_utils.py index 6799058e0c..a1acf484ab 100644 --- a/src/peft/import_utils.py +++ b/src/peft/import_utils.py @@ -77,3 +77,8 @@ def is_aqlm_available(): @lru_cache def is_auto_awq_available(): return importlib.util.find_spec("awq") is not None + + +@lru_cache +def is_eetq_available(): + return importlib.util.find_spec("eetq") is not None diff --git a/src/peft/tuners/lora/__init__.py b/src/peft/tuners/lora/__init__.py index 3115fff724..2a0bce2a5f 100644 --- a/src/peft/tuners/lora/__init__.py +++ b/src/peft/tuners/lora/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from peft.import_utils import is_bnb_4bit_available, is_bnb_available +from peft.import_utils import is_bnb_4bit_available, is_bnb_available, is_eetq_available from .config import LoftQConfig, LoraConfig from .gptq import QuantLinear @@ -34,4 +34,9 @@ def __getattr__(name): return Linear4bit + if (name == "EetqLoraLinear") and is_eetq_available(): + from .eetq import EetqLoraLinear + + return EetqLoraLinear + raise AttributeError(f"module {__name__} has no attribute {name}") diff --git a/src/peft/tuners/lora/eetq.py b/src/peft/tuners/lora/eetq.py new file mode 100644 index 0000000000..6bf42c6814 --- /dev/null +++ b/src/peft/tuners/lora/eetq.py @@ -0,0 +1,104 @@ +# Copyright 2024-present the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, List, Optional + +import torch + +from peft.import_utils import is_eetq_available +from peft.tuners.lora.layer import LoraLayer +from peft.tuners.tuners_utils import BaseTunerLayer + + +if is_eetq_available(): + from eetq import EetqLinear + + class EetqLoraLinear(torch.nn.Module, LoraLayer): + def __init__( + self, + base_layer, + adapter_name, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + init_lora_weights: bool = True, + use_rslora: bool = False, + **kwargs, + ): + super().__init__() + LoraLayer.__init__(self, base_layer) + + # self.base_layer and self.quant_linear_module are the same; we need the former for consistency and the latter + # for backwards compatibility + self.quant_linear_module = base_layer + + self._active_adapter = adapter_name + self.update_layer(adapter_name, r, lora_alpha, lora_dropout, init_lora_weights, use_rslora) + + def forward(self, x: torch.Tensor): + result = self.quant_linear_module(x) + + if self.disable_adapters: + return result + + for active_adapter in self.active_adapters: + if active_adapter not in self.lora_A.keys(): + continue + lora_A = self.lora_A[active_adapter] + lora_B = self.lora_B[active_adapter] + dropout = self.lora_dropout[active_adapter] + scaling = self.scaling[active_adapter] + + requires_conversion = not torch.is_autocast_enabled() + if requires_conversion: + expected_dtype = result.dtype + x = x.to(lora_A.weight.dtype) + + output = lora_B(lora_A(dropout(x))) + if requires_conversion: + output = output.to(expected_dtype) + output = output * scaling + result = result + output + return result + + def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = None) -> None: + raise AttributeError("Merging LoRA layers is not supported for Eetq layers.") + + def unmerge(self) -> None: + raise AttributeError("Unmerging LoRA layers is not supported for Eetq layers.") + + def __repr__(self) -> str: + rep = super().__repr__() + return "lora." + rep + + +def dispatch_eetq( + target: torch.nn.Module, + adapter_name: str, + **kwargs: Any, +) -> Optional[torch.nn.Module]: + new_module = None + + if isinstance(target, BaseTunerLayer): + target_base_layer = target.get_base_layer() + else: + target_base_layer = target + + if is_eetq_available() and isinstance(target_base_layer, EetqLinear): + new_module = EetqLoraLinear(target, adapter_name, **kwargs) + target.weight = target_base_layer.weight + + if hasattr(target, "bias"): + target.bias = target_base_layer.bias + + return new_module diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 024257d182..689b921371 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -77,6 +77,9 @@ def __init__(self, base_layer: nn.Module, **kwargs) -> None: elif hasattr(base_layer, "w_bit") and base_layer.__class__.__name__ == "WQLinear_GEMM": # Awq layers in_features, out_features = base_layer.in_features, base_layer.out_features + elif base_layer.__class__.__name__ == "EetqLinear": + # Eetq layers + in_features, out_features = base_layer.in_features, base_layer.out_features else: raise ValueError(f"Unsupported layer type {type(base_layer)}") diff --git a/src/peft/tuners/lora/model.py b/src/peft/tuners/lora/model.py index 4ed41012d8..6b3fbc6a69 100644 --- a/src/peft/tuners/lora/model.py +++ b/src/peft/tuners/lora/model.py @@ -48,6 +48,7 @@ from .aqlm import dispatch_aqlm from .awq import dispatch_awq from .config import LoraConfig +from .eetq import dispatch_eetq from .gptq import dispatch_gptq from .layer import Conv2d, LoraLayer, dispatch_default from .tp_layer import dispatch_megatron @@ -288,7 +289,9 @@ def _create_new_module(lora_config, adapter_name, target, **kwargs): dispatchers.append(dispatch_bnb_4bit) - dispatchers.extend([dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default]) + dispatchers.extend( + [dispatch_eetq, dispatch_aqlm, dispatch_awq, dispatch_gptq, dispatch_megatron, dispatch_default] + ) new_module = None for dispatcher in dispatchers: diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 503db7d9d3..03f6507975 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -95,6 +95,8 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False) is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq" is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm" + is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq" + if gradient_checkpointing_kwargs is None: gradient_checkpointing_kwargs = {} @@ -102,7 +104,7 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad # freeze base model's layers param.requires_grad = False - if not is_gptq_quantized and not is_aqlm_quantized: + if not is_gptq_quantized and not is_aqlm_quantized and not is_eetq_quantized: # cast all non INT8 parameters to fp32 for param in model.parameters(): if ( @@ -110,7 +112,7 @@ def prepare_model_for_kbit_training(model, use_gradient_checkpointing=True, grad ) and param.__class__.__name__ != "Params4bit": param.data = param.data.to(torch.float32) - if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized) and use_gradient_checkpointing: + if (loaded_in_kbit or is_gptq_quantized or is_aqlm_quantized or is_eetq_quantized) and use_gradient_checkpointing: # When having `use_reentrant=False` + gradient_checkpointing, there is no need for this hack if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]: # For backward compatibility diff --git a/tests/test_gpu_examples.py b/tests/test_gpu_examples.py index a77b916220..91b09a0fbd 100644 --- a/tests/test_gpu_examples.py +++ b/tests/test_gpu_examples.py @@ -60,6 +60,7 @@ require_auto_awq, require_auto_gptq, require_bitsandbytes, + require_eetq, require_optimum, require_torch_gpu, require_torch_multi_gpu, @@ -2072,6 +2073,153 @@ def test_causal_lm_training_multi_gpu(self): assert trainer.state.log_history[-1]["train_loss"] is not None +@require_torch_gpu +@require_eetq +class PeftEetqGPUTests(unittest.TestCase): + r""" + EETQ + peft tests + """ + + def setUp(self): + self.causal_lm_model_id = "facebook/opt-125m" + self.tokenizer = AutoTokenizer.from_pretrained(self.causal_lm_model_id) + + def tearDown(self): + r""" + Efficient mechanism to free GPU memory after each test. Based on + https://github.com/huggingface/transformers/issues/21094 + """ + gc.collect() + torch.cuda.empty_cache() + + def _check_inference_finite(self, model, batch): + # try inference without Trainer class + training = model.training + model.eval() + output = model(**batch.to(model.device)) + assert torch.isfinite(output.logits).all() + model.train(training) + + @pytest.mark.single_gpu_tests + def test_causal_lm_training_eetq(self): + r""" + Test the CausalLM training on a single GPU device. The test would simply fail if the adapters are not set + correctly. + """ + from transformers import EetqConfig + + with tempfile.TemporaryDirectory() as tmp_dir: + quantization_config = EetqConfig("int8") + + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, device_map="auto", quantization_config=quantization_config + ) + + model = prepare_model_for_kbit_training(model) + + config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + model = get_peft_model(model, config) + + data = load_dataset("ybelkada/english_quotes_copy") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + @pytest.mark.multi_gpu_tests + @require_torch_multi_gpu + def test_causal_lm_training_multi_gpu_eetq(self): + r""" + Test the CausalLM training on a multi-GPU device. The test would simply fail if the adapters are not set + correctly. + """ + from transformers import EetqConfig + + with tempfile.TemporaryDirectory() as tmp_dir: + quantization_config = EetqConfig("int8") + + model = AutoModelForCausalLM.from_pretrained( + self.causal_lm_model_id, + device_map="auto", + quantization_config=quantization_config, + ) + + assert set(model.hf_device_map.values()) == set(range(torch.cuda.device_count())) + + model = prepare_model_for_kbit_training(model) + + setattr(model, "model_parallel", True) + setattr(model, "is_parallelizable", True) + + config = LoraConfig( + r=16, + lora_alpha=32, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.05, + bias="none", + task_type="CAUSAL_LM", + ) + + model = get_peft_model(model, config) + + data = load_dataset("Abirate/english_quotes") + data = data.map(lambda samples: self.tokenizer(samples["quote"]), batched=True) + + trainer = Trainer( + model=model, + train_dataset=data["train"], + args=TrainingArguments( + per_device_train_batch_size=4, + gradient_accumulation_steps=4, + warmup_steps=2, + max_steps=3, + learning_rate=2e-4, + logging_steps=1, + output_dir=tmp_dir, + ), + data_collator=DataCollatorForLanguageModeling(self.tokenizer, mlm=False), + ) + model.config.use_cache = False + trainer.train() + + model.cpu().save_pretrained(tmp_dir) + + assert "adapter_config.json" in os.listdir(tmp_dir) + assert SAFETENSORS_WEIGHTS_NAME in os.listdir(tmp_dir) + + # assert loss is not None + assert trainer.state.log_history[-1]["train_loss"] is not None + + PRECISIONS = [(torch.float32), (torch.float16), (torch.bfloat16)] LORA_PARAMS = { diff --git a/tests/testing_utils.py b/tests/testing_utils.py index eee33b5a67..28ba3a2b32 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -18,7 +18,13 @@ import pytest import torch -from peft.import_utils import is_aqlm_available, is_auto_awq_available, is_auto_gptq_available, is_optimum_available +from peft.import_utils import ( + is_aqlm_available, + is_auto_awq_available, + is_auto_gptq_available, + is_eetq_available, + is_optimum_available, +) def require_torch_gpu(test_case): @@ -75,6 +81,13 @@ def require_auto_awq(test_case): return unittest.skipUnless(is_auto_awq_available(), "test requires auto-awq")(test_case) +def require_eetq(test_case): + """ + Decorator marking a test that requires eetq. These tests are skipped when eetq isn't installed. + """ + return unittest.skipUnless(is_eetq_available(), "test requires eetq")(test_case) + + def require_optimum(test_case): """ Decorator marking a test that requires optimum. These tests are skipped when optimum isn't installed. From 383e1fab0e4a3852cf3d90b4eb357504b04ad1f7 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 26 Apr 2024 10:48:05 +0200 Subject: [PATCH 07/10] Update build_docker_images.yml (#1682) --- .github/workflows/build_docker_images.yml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.github/workflows/build_docker_images.yml b/.github/workflows/build_docker_images.yml index 64f9448109..e53b00f59b 100644 --- a/.github/workflows/build_docker_images.yml +++ b/.github/workflows/build_docker_images.yml @@ -46,6 +46,7 @@ jobs: tags: huggingface/peft-cpu - name: Post to a Slack channel + if: always() id: slack #uses: slackapi/slack-github-action@v1.25.0 uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 @@ -102,6 +103,7 @@ jobs: tags: huggingface/peft-gpu - name: Post to a Slack channel + if: always() id: slack #uses: slackapi/slack-github-action@v1.25.0 uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 @@ -159,6 +161,7 @@ jobs: - name: Post to a Slack channel + if: always() id: slack #uses: slackapi/slack-github-action@v1.25.0 uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 @@ -216,6 +219,7 @@ jobs: tags: huggingface/peft-gpu-bnb-latest - name: Post to a Slack channel + if: always() id: slack #uses: slackapi/slack-github-action@v1.25.0 uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 @@ -272,6 +276,7 @@ jobs: tags: huggingface/peft-gpu-bnb-multi-source - name: Post to a Slack channel + if: always() id: slack #uses: slackapi/slack-github-action@v1.25.0 uses: slackapi/slack-github-action@6c661ce58804a1a20f6dc5fbee7f0381b469e001 From 8bc3c0861d4a06d1501b5e8954079804419bb0e6 Mon Sep 17 00:00:00 2001 From: Younes Belkada <49240599+younesbelkada@users.noreply.github.com> Date: Fri, 26 Apr 2024 15:49:02 +0200 Subject: [PATCH 08/10] Update Dockerfile (#1684) --- docker/peft-gpu/Dockerfile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docker/peft-gpu/Dockerfile b/docker/peft-gpu/Dockerfile index 37e5b76fc9..9cc97ead71 100644 --- a/docker/peft-gpu/Dockerfile +++ b/docker/peft-gpu/Dockerfile @@ -42,9 +42,9 @@ RUN source activate peft && \ # Add autoawq for quantization testing RUN source activate peft && \ - python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.1/autoawq-0.2.1-cp38-cp38-linux_x86_64.whl + python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ/releases/download/v0.2.4/autoawq-0.2.4-cp38-cp38-linux_x86_64.whl RUN source activate peft && \ - python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.4/autoawq_kernels-0.0.4-cp38-cp38-linux_x86_64.whl + python3 -m pip install --no-cache-dir https://github.com/casper-hansen/AutoAWQ_kernels/releases/download/v0.0.6/autoawq_kernels-0.0.6-cp38-cp38-linux_x86_64.whl # Install apt libs RUN apt-get update && \ From e7b47ac01d14c9b15fd970c29da89b779cfabfc6 Mon Sep 17 00:00:00 2001 From: Benjamin Bossan Date: Mon, 29 Apr 2024 11:35:47 +0200 Subject: [PATCH 09/10] FIX Init DoRA weights in float32 if float16 used (#1653) When DoRA weights are initialized in float16 on CPU and when an older PyTorch version is being used (<2.2), there is an error because the the operation is not supported for float16 on CPU. This commit temporarily converts the LoRA weights to float32 beforehand if they're in float16. Of course, when the user tries to train or predict with this model on CPU, they will still encounter errors. However, in certain situations, only the initialization might be on CPU and later it is moved to GPU. This could be some framework code that the user has no control over, as in #1597. Therefore, it's good to have this safety hatch. Note that since our CI uses the latest PyTorch version, we cannot run a test for this, as the latest PyTorch runs no matter what. --- src/peft/tuners/lora/layer.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/src/peft/tuners/lora/layer.py b/src/peft/tuners/lora/layer.py index 689b921371..56ab7c4a1a 100644 --- a/src/peft/tuners/lora/layer.py +++ b/src/peft/tuners/lora/layer.py @@ -181,19 +181,29 @@ def _get_weight_norm(self, weight, lora_weight, scaling) -> torch.Tensor: return weight_norm def dora_init(self, adapter_name: str) -> None: - lora_A = self.lora_A[adapter_name] - lora_B = self.lora_B[adapter_name] + lora_A = self.lora_A[adapter_name].weight + lora_B = self.lora_B[adapter_name].weight + # temporarily convert fp16 to fp32, as fp16 can cause trouble on CPU with PyTorch < 2.2 + dtype_is_fp16 = lora_A.dtype == torch.float16 + if dtype_is_fp16: + lora_A = lora_A.float() + lora_B = lora_B.float() + scaling = self.scaling[adapter_name] with gather_params_ctx(self.get_base_layer().parameters()): weight = self.get_base_layer().weight quant_state = getattr(self.get_base_layer(), "state", None) weight = dequantize_bnb_weight(weight, state=quant_state) # no-op if not bnb if weight.data.ndim == 4: # For handling LoRAs applied to Conv2Ds. - lora_weight = torch.mm(lora_B.weight.flatten(start_dim=1), lora_A.weight.flatten(start_dim=1)) + lora_weight = torch.mm(lora_B.flatten(start_dim=1), lora_A.flatten(start_dim=1)) lora_weight = lora_weight.reshape(weight.shape) else: - lora_weight = lora_B.weight @ lora_A.weight + lora_weight = lora_B @ lora_A + + if dtype_is_fp16: + lora_weight = lora_weight.half() weight_norm = self._get_weight_norm(weight, lora_weight, scaling) + self.lora_magnitude_vector = nn.ParameterDict() self.lora_magnitude_vector[adapter_name] = nn.Parameter(weight_norm, requires_grad=True) # add lora_magnitude_vector to the list of learnable parameters From 7a22b7daf0b3abf4cea98fb2ee8cb60f5d28cc99 Mon Sep 17 00:00:00 2001 From: "Wang, Yi" Date: Mon, 29 Apr 2024 17:50:42 +0800 Subject: [PATCH 10/10] FIX bf16 dtype issue for IA3 (#1634) Signed-off-by: Wang, Yi A --- src/peft/tuners/ia3/layer.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/src/peft/tuners/ia3/layer.py b/src/peft/tuners/ia3/layer.py index 9ea04e6873..aef376860c 100644 --- a/src/peft/tuners/ia3/layer.py +++ b/src/peft/tuners/ia3/layer.py @@ -111,6 +111,7 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N if active_adapter in self.ia3_l.keys(): base_layer = self.get_base_layer() ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + orig_dtype = base_layer.weight.data.dtype if safe_merge: orig_weights = base_layer.weight.data orig_weights = torch.mul(orig_weights, ia3_l) @@ -119,13 +120,14 @@ def merge(self, safe_merge: bool = False, adapter_names: Optional[List[str]] = N raise ValueError( f"NaNs detected in the merged weights. The adapter {active_adapter} seems to be broken" ) - base_layer.weight.data = orig_weights + base_layer.weight.data = orig_weights.to(orig_dtype) else: - base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_l) + base_layer.weight.data = torch.mul(base_layer.weight.data, ia3_l).to(orig_dtype) if not self.is_feedforward and (base_layer.bias is not None): scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) - base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data) + orig_dtype = base_layer.bias.data.dtype + base_layer.bias.data = torch.mul(base_layer.bias.data, scaling.data).to(orig_dtype) self.merged_adapters.append(active_adapter) @@ -144,15 +146,16 @@ def unmerge(self) -> None: base_layer = self.get_base_layer() # Add tolerace to avoid division by zero ia3_l = transpose(self.ia3_l[active_adapter].data, self.fan_in_fan_out) + 1e-8 - base_layer.weight.data = torch.div(base_layer.weight.data, ia3_l) + orig_dtype = base_layer.weight.data.dtype + base_layer.weight.data = torch.div(base_layer.weight.data, ia3_l).to(orig_dtype) if not self.is_feedforward and (base_layer.bias is not None): scaling = self.ia3_l[active_adapter].reshape(base_layer.bias.shape) - base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8) + orig_dtype = base_layer.bias.data.dtype + base_layer.bias.data = torch.div(base_layer.bias.data, scaling.data + 1e-8).to(orig_dtype) def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: dtype = previous_dtype = x.dtype - if self.disable_adapters: if self.merged: self.unmerge() @@ -171,13 +174,13 @@ def forward(self, x: torch.Tensor, *args: Any, **kwargs: Any) -> torch.Tensor: x = x.to(dtype) # TODO: weight.dtype can be != self.ia3_l[self.active_adapters].dtype # e.g. bf16 vs fp32. Is that okay? - interm = (x * ia3_scaling).to(self.get_base_layer().weight.dtype) + interm = (x * ia3_scaling).to(previous_dtype) result = self.base_layer(interm, *args, **kwargs) else: result = self.base_layer(x, *args, **kwargs) - result = result.to(dtype) * ia3_scaling + result_dtype = result.dtype + result = (result * ia3_scaling).to(result_dtype) - result = result.to(previous_dtype) return result