Commit
Fixing DoRA docs, adding to mem opt tutorial (#1918)
SalmanMohammadi authored Oct 29, 2024
1 parent 1f5e21d commit 48a8449
Showing 4 changed files with 115 additions and 15 deletions.
1 change: 1 addition & 0 deletions docs/source/api_ref_modules.rst
@@ -71,6 +71,7 @@ PEFT Components
:nosignatures:

peft.LoRALinear
peft.DoRALinear
peft.AdapterModule
peft.get_adapter_params
peft.set_trainable_params
1 change: 1 addition & 0 deletions docs/source/recipes/lora_finetune_single_device.rst
@@ -44,6 +44,7 @@ see our documentation for the different PEFT training paradigms we support:

* :ref:`glossary_lora`
* :ref:`glossary_qlora`
* :ref:`glossary_dora`

Many of our other memory optimization features can be used in this recipe. You can learn more about all of our memory optimization features in our :ref:`memory optimization overview<memory_optimization_overview_label>`.

111 changes: 105 additions & 6 deletions docs/source/tutorials/memory_optimizations.rst
@@ -23,6 +23,7 @@ To make things easy, we've summarized these components in the following table:
":ref:`glossary_cpu_offload`", "Offloads optimizer states and (optionally) gradients to CPU, and performs optimizer steps on CPU. This can be used to significantly reduce GPU memory usage at the cost of CPU RAM and training speed, as CPU optimizer steps can be slow and bottleneck training performance."
":ref:`glossary_lora`", "When you want to significantly reduce the number of trainable parameters, saving gradient and optimizer memory during training, and significantly speeding up training."
":ref:`glossary_qlora`", "When you need even more memory savings than LoRA, at the potential cost of some training speed. Useful for very large models or limited hardware."
":ref:`glossary_dora`", "Like LoRA, DoRA can provide significant memory savings and training speed-ups. DoRA may improve performance over LoRA, particularly when using small rank updates."


.. note::
@@ -110,7 +111,7 @@ checkpointing, where all activations will either be recomputed later in the back

To enable activation offloading, use the ``enable_activation_offloading`` config entry or flag
in our lora finetuning single device recipe, e.g. ``enable_activation_offloading=True``. To allow
usage of streams, make sure you are on a torch version later than PyTorch 2.5.0.
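
For example, a minimal sketch of enabling this from the command line, using the same ``llama3/8B_lora_single_device`` config referenced elsewhere in this tutorial, might look like:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
    enable_activation_offloading=True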

.. _glossary_grad_accm:

@@ -343,6 +344,7 @@ These are all specified under the ``model`` flag or config entry, i.e:
.. code-block:: yaml

    model:
      _component_: torchtune.models.llama3.lora_llama3_8b
      apply_lora_to_mlp: True
      lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
@@ -357,7 +359,24 @@ Secondly, parameters which control the scale of the impact of LoRA on the model:
to your specific use case. Typically, one jointly changes ``lora_rank`` and ``lora_alpha`` together, where ``lora_alpha ~= 2*lora_rank``.
* ``lora_dropout`` introduces dropout in the LoRA layers to help regularize training. We default to 0.0 for all of our models.

As above, these parameters are also specified under the ``model`` flag or config entry:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
    model.apply_lora_to_mlp=True \
    model.lora_attn_modules=["q_proj","k_proj","v_proj"] \
    model.lora_rank=32 \
    model.lora_alpha=64

.. code-block:: yaml

    model:
      _component_: torchtune.models.llama3.lora_llama3_8b
      apply_lora_to_mlp: True
      lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
      lora_rank: 32
      lora_alpha: 64

.. note::

@@ -388,18 +407,98 @@ You can finetune using QLoRA with any of our LoRA recipes, i.e. recipes with the
QLoRA-enabled model builders, which we support for all our models, and also use the ``qlora_`` prefix, e.g.
the :func:`torchtune.models.llama3.llama3_8b` model has a corresponding :func:`torchtune.models.llama3.qlora_llama3_8b`.
We aim to provide a comprehensive set of configurations to allow you to get started with training with QLoRA quickly,
just specify any config with ``_qlora`` in its name.

All the rest of the LoRA parameters remain the same for QLoRA - check out the section above on :ref:`LoRA <glossary_lora>`
to see how to configure these parameters.

To configure from the command line:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama3/8B_qlora_single_device \
    model.apply_lora_to_mlp=True \
    model.lora_attn_modules=["q_proj","k_proj","v_proj"] \
    model.lora_rank=32 \
    model.lora_alpha=64

or, by modifying a config:

.. code-block:: yaml

    model:
      _component_: torchtune.models.llama3.qlora_llama3_8b
      apply_lora_to_mlp: True
      lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
      lora_rank: 32
      lora_alpha: 64

.. _glossary_dora:

Weight-Decomposed Low-Rank Adaptation (DoRA)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

*What's going on here?*

`DoRA <https://arxiv.org/abs/2402.09353>`_ is another PEFT technique which builds on top of LoRA by
further decomposing the pre-trained weights into two components: magnitude and direction. The magnitude component
is a learnable vector that scales each output channel, while the direction component corresponds to the original LoRA decomposition and
updates the orientation of the weights.
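
To make the decomposition concrete, the DoRA-adapted weight can be written schematically as follows (a sketch following the DoRA paper, where :math:`W_0` is the pre-trained weight, :math:`BA` is the low-rank LoRA update scaled by :math:`\alpha / r`, and :math:`m` is the learnable magnitude vector with one entry per output channel):

.. math::

    W' = m \odot \frac{W_0 + \frac{\alpha}{r} B A}{\lVert W_0 + \frac{\alpha}{r} B A \rVert}

Here the norm is taken per output channel, so each output dimension of the LoRA-updated weight is renormalized and then rescaled by its own learnable magnitude.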

DoRA adds a small overhead to LoRA training due to the addition of the magnitude parameter, but it has been shown to
improve the performance of LoRA, particularly at low ranks.

*Sounds great! How do I use it?*

Much like LoRA and QLoRA, you can finetune using DoRA with any of our LoRA recipes. We use the same model builders for LoRA
as we do for DoRA, so you can use the ``lora_`` version of any model builder with ``use_dora=True``. For example, to finetune
:func:`torchtune.models.llama3.llama3_8b` with DoRA, you would use :func:`torchtune.models.llama3.lora_llama3_8b` with ``use_dora=True``:

.. code-block:: bash

    tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
    model.use_dora=True

.. code-block:: yaml

    model:
      _component_: torchtune.models.llama3.lora_llama3_8b
      use_dora: True

Since DoRA extends LoRA, the parameters for :ref:`customizing LoRA <glossary_lora>` are identical. You can also quantize the base model weights like in :ref:`glossary_qlora` by using ``quantize_base=True`` to reap
even more memory savings!

.. code-block:: bash

    tune run lora_finetune_single_device --config llama3/8B_lora_single_device \
    model.apply_lora_to_mlp=True \
    model.lora_attn_modules=["q_proj","k_proj","v_proj"] \
    model.lora_rank=16 \
    model.lora_alpha=32 \
    model.use_dora=True \
    model.quantize_base=True

.. code-block:: yaml

    model:
      _component_: torchtune.models.llama3.lora_llama3_8b
      apply_lora_to_mlp: True
      lora_attn_modules: ["q_proj", "k_proj", "v_proj"]
      lora_rank: 16
      lora_alpha: 32
      use_dora: True
      quantize_base: True

.. note::

    Under the hood, we've enabled DoRA by adding the :class:`~torchtune.modules.peft.DoRALinear` module, which we swap
    out for :class:`~torchtune.modules.peft.LoRALinear` when ``use_dora=True``.
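
For example, a quick way to see this swap in an interactive Python session is a sketch like the following. Note that this is illustrative only: it instantiates the full Llama3 8B model in memory, and the builder keyword arguments simply mirror the config overrides shown above.

.. code-block:: python

    from torchtune.models.llama3 import lora_llama3_8b
    from torchtune.modules.peft import DoRALinear

    # Build the LoRA Llama3 8B model with DoRA enabled; these keyword arguments
    # mirror the ``model.*`` config overrides used earlier in this section.
    model = lora_llama3_8b(
        lora_attn_modules=["q_proj", "k_proj", "v_proj"],
        apply_lora_to_mlp=True,
        lora_rank=16,
        lora_alpha=32,
        use_dora=True,
    )

    # With use_dora=True, the adapted projections are DoRALinear modules.
    print(any(isinstance(m, DoRALinear) for m in model.modules()))  # True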

.. _glossary_distrib:


.. TODO
.. Distributed
17 changes: 8 additions & 9 deletions torchtune/modules/peft/dora.py
@@ -18,15 +18,14 @@


class DoRALinear(nn.Module, AdapterModule):
"""LoRA linear layer as introduced in `LoRA: Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2106.09685>`_.
LoRA perturbs a given layer via a low-rank approximation where only
the rank decomposition matrices are trainable. In a linear layer instead of
:math:`x \\mapsto W_0x` a LoRALinear layer is defined as
:math:`x \\mapsto W_0x + (\\alpha / r)BAx`, where :math:`r` is the rank of
the matrices :math:`A` and :math:`B` and :math:`\\alpha` is a scaling factor.
As in the original implementation, we support dropout before multiplication
by the low-rank matrices.
"""DoRA linear layer as introduced in
`DoRA: Weight-Decomposed Low-Rank Adaptation of Large Language Models <https://arxiv.org/abs/2402.09353>`_.
DoRA (Weight-Decomposed Low-Rank Adaptation) fine-tunes a layer by decomposing the pre-trained weights
into two components: magnitude and direction. The magnitude component is a learnable scalar vector
that scales each output channel, while the direction component, modified via LoRA, adjusts the orientation
of weights. By scaling the LoRA update component :math:`BAx` with the `magnitude` vector, DoRA allows the model
to apply distinct scaling adjustments across different output dimensions.
Args:
in_dim (int): input dimension