Add Upcasting for FSDP in Mixed Precision. Add Concept Guide for FSDP and DeepSpeed. #2674

Merged · 15 commits · Apr 29, 2024
docs/source/_toctree.yml (2 additions, 0 deletions)
@@ -78,6 +78,8 @@
title: Executing and deferring jobs
- local: concept_guides/gradient_synchronization
title: Gradient synchronization
- local: concept_guides/fsdp_and_deepspeed
title: FSDP vs DeepSpeed
- local: concept_guides/low_precision_training
title: How training in low-precision environments is possible (FP8)
- local: concept_guides/training_tpu
docs/source/concept_guides/fsdp_and_deepspeed.md (178 additions, 0 deletions)
@@ -0,0 +1,178 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->

# Moving between FSDP and DeepSpeed

🤗 Accelerate offers flexibility in the choice of training framework by integrating two extremely powerful tools for distributed training, namely [PyTorch FSDP](../usage_guides/fsdp.md) and [Microsoft DeepSpeed](../usage_guides/deepspeed.md). The aim of this tutorial is to draw parallels, as well as to outline potential differences, to empower the user to switch seamlessly between these two frameworks.

<Tip>

To switch between the frameworks, we recommend launching code with 🤗 `accelerate launch` and passing in the correct config file with `--config_file`, or passing in the respective arguments directly for [FSDP and DeepSpeed](../package_reference/cli#accelerate-launch).

Example 🤗 Accelerate configurations can be found here for [DeepSpeed](../usage_guides/deepspeed#accelerate-deepspeed-plugin) and [FSDP](../usage_guides/fsdp#how-it-works-out-of-the-box), or in the [example zoo under "Launch Configurations"](../usage_guides/explore).

</Tip>

<Tip warning={true}>

This tutorial is for single-node, multi-GPU scenarios only.

</Tip>

## Configuring Functionalities

Model tensors are split across different GPUs in an attempt to scale up model sizes; this is termed *sharding* in FSDP, and *partitioning* in DeepSpeed. FSDP sharding and DeepSpeed ZeRO (partitioning) stages are configured by `--fsdp_sharding_strategy` and `--zero_stage`, respectively. In particular, FSDP `FULL_SHARD` maps to DeepSpeed ZeRO stage `3`; see this [comprehensive mapping between FSDP sharding and DeepSpeed ZeRO settings](../usage_guides/fsdp#mapping-between-fsdp-sharding-strategies-and-deepspeed-zero-stages). The table below summarizes and groups similar settings:

Group | Framework | Configuration | Example | Restrictions (if any)
--|--|--|--|--
sharding / partitioning | FSDP<br>DeepSpeed | `--fsdp_sharding_strategy`<br>`--zero_stage` | `1` (`FULL_SHARD`) <br>`3` |
offload | FSDP<br>DeepSpeed | `--fsdp_offload_params`<br>`--offload_param_device`<br>`--offload_optimizer_device` | `true`<br>`cpu`<br>`cpu` | all or nothing <br><br>
model loading | FSDP<br>DeepSpeed | <span style="white-space:nowrap;">`--fsdp_cpu_ram_efficient_loading`</span><br>`--zero3_init_flag` | `true`<br>`true` | <br>only ZeRO 3
efficient checkpointing | FSDP<br>DeepSpeed | `--fsdp_state_dict_type`<br>`--zero3_save_16bit_model` | `SHARDED_STATE_DICT`<br>`true` | <br>only ZeRO 3
prefetching | FSDP<br><br>DeepSpeed | `--fsdp_forward_prefetch`<br>`--fsdp_backward_prefetch`<br>None | `true`<br>`BACKWARD_PRE` | <br><br>
model | FSDP<br><br>DeepSpeed | `--fsdp_auto_wrap_policy`<br><span style="white-space:nowrap;">`--fsdp_transformer_layer_cls_to_wrap`</span><br>None | `TRANSFORMER_BASED_WRAP`<br><Layer Class> |<br>Usually not needed <br>Transparent to user.
parameters summoning | FSDP<br>DeepSpeed | `--fsdp_use_orig_params`<br>None | `true` | required for `torch.compile`<br>Transparent to user
parameters syncing | FSDP<br>DeepSpeed | `--fsdp_sync_module_states`<br>None | `true` |
training | FSDP<br>DeepSpeed | None<br>`--gradient_accumulation_steps`<br>`--gradient_clipping` | <br>`auto`<br>`auto` | Transparent to user

For detailed descriptions of the above, refer to [🤗 `Accelerate` launch documentation](../package_reference/cli#accelerate-launch).

<Tip>

To access other DeepSpeed configurations, such as mixed precision settings,
you need to pass in a `--deepspeed_config_file`; see the [documentation](../usage_guides/deepspeed#deepspeed-config-file).

DeepSpeed can also be configured via [`DeepSpeedPlugin`], e.g., `DeepSpeedPlugin.zero_stage` is equivalent to `--zero_stage`, and `DeepSpeedPlugin.hf_ds_config` can be used to pass `--deepspeed_config_file`.

</Tip>

<Tip>

FSDP can also be configured via [`FullyShardedDataParallelPlugin`], e.g., `FullyShardedDataParallelPlugin.sharding_strategy` is equivalent to `--fsdp_sharding_strategy`.

</Tip>
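
For instance, a minimal sketch of the plugin route (only one `Accelerator` should be constructed per script; both are shown here purely for comparison, and the exact plugin arguments may differ across versions):

```python
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import ShardingStrategy

# DeepSpeed: roughly equivalent to `--zero_stage 3` with CPU offload of
# parameters and optimizer state.
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,
    offload_param_device="cpu",
    offload_optimizer_device="cpu",
)
accelerator = Accelerator(mixed_precision="bf16", deepspeed_plugin=deepspeed_plugin)

# FSDP: roughly equivalent to `--fsdp_sharding_strategy FULL_SHARD`.
fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD)
accelerator = Accelerator(mixed_precision="bf16", fsdp_plugin=fsdp_plugin)
```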

### Checkpointing

Do note that FSDP can be configured via `--fsdp_state_dict_type` to save either full or sharded checkpoints.

<Tip>

For DeepSpeed ZeRO-3, it is recommended to also pass `--zero3_save_16bit_model true`, which conveniently consolidates the model to a single rank and saves it; this is the equivalent of FSDP's `fsdp_state_dict_type: FULL_STATE_DICT`.

</Tip>
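
As a rough sketch of what saving can look like with either backend (the `accelerator`, `model`, and output paths are assumptions; exact behaviour depends on the flags above):

```python
# `model` is assumed to have gone through `accelerator.prepare(...)`.

# Sharded / partitioned checkpoint of the full training state (model, optimizer, RNG state, ...):
accelerator.save_state("ckpt_dir")

# Consolidated weights for later standalone use:
# - under FSDP this respects `fsdp_state_dict_type`,
# - under DeepSpeed ZeRO-3 it relies on `zero3_save_16bit_model: true` to gather the weights.
accelerator.wait_for_everyone()
state_dict = accelerator.get_state_dict(model)
if accelerator.is_main_process:
    accelerator.save(state_dict, "model_weights.bin")
```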

### Offloading

FSDP only allows *all-or-nothing* offload (i.e., either offload parameters, gradients, and optimizer states, or keep them all on the GPU), whereas DeepSpeed can offload parameters and optimizer states independently. Furthermore, DeepSpeed also supports [offloading to NVMe](https://www.deepspeed.ai/docs/config-json/#parameter-offloading).
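
A sketch of how this difference surfaces in the plugins (field names such as `cpu_offload` and the device strings are assumptions based on the flags above and may vary across versions):

```python
from accelerate.utils import DeepSpeedPlugin, FullyShardedDataParallelPlugin
from torch.distributed.fsdp import CPUOffload

# FSDP: parameters, gradients and optimizer state are offloaded together, or not at all.
fsdp_plugin = FullyShardedDataParallelPlugin(cpu_offload=CPUOffload(offload_params=True))

# DeepSpeed: parameter and optimizer placement are chosen independently,
# mirroring --offload_param_device / --offload_optimizer_device.
deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,
    offload_param_device="none",     # keep parameters on GPU ...
    offload_optimizer_device="cpu",  # ... while pushing optimizer state to CPU (or NVMe)
)
```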

### Prefetching

FSDP allows two prefetching configurations, `--fsdp_forward_prefetch` and `--fsdp_backward_prefetch`, to improve the overlap of communication and computation at the cost of extra memory; see the [FSDP documentation](https://pytorch.org/docs/stable/fsdp.html).
For DeepSpeed, prefetching is always on, and only certain hyperparameters like `stage3_prefetch_bucket_size` [can be configured for ZeRO-3](https://www.deepspeed.ai/docs/config-json/#parameter-offloading); 🤗 `accelerate` will set these hyperparameters automatically.

<Tip>

For FSDP, set `fsdp_backward_prefetch: BACKWARD_PRE` for improved throughput if memory allows.

</Tip>
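
A sketch of the equivalent plugin settings (the plugin field names are assumptions mirroring the `--fsdp_*` flags):

```python
from accelerate.utils import FullyShardedDataParallelPlugin
from torch.distributed.fsdp import BackwardPrefetch

# BACKWARD_PRE prefetches the next parameter all-gather before the current
# gradient computation, trading some extra memory for better comm/compute overlap.
fsdp_plugin = FullyShardedDataParallelPlugin(
    backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
    forward_prefetch=True,  # needs a sufficiently recent torch release
)
```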

### Model Loading

While FSDP requires an explicit `--fsdp_cpu_ram_efficient_loading true` to activate efficient model loading, 🤗 `transformers` will activate a similar feature whenever DeepSpeed ZeRO-3 is used.

<Tip>

For FSDP, when setting `--fsdp_cpu_ram_efficient_loading true`, please also set `--fsdp_sync_module_states true`, otherwise the model will not load properly.

</Tip>
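
For reference, the relevant `fsdp_config` entries look like the following (written as a Python dict purely for illustration; in practice they live in the YAML produced by `accelerate config`):

```python
# These keys mirror the launch flags discussed above.
fsdp_config = {
    "fsdp_cpu_ram_efficient_loading": True,  # only rank 0 materializes the pretrained checkpoint
    "fsdp_sync_module_states": True,         # rank 0 then broadcasts the weights to the other ranks
}
```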

### Model

FSDP requires an explicit `--fsdp_auto_wrap_policy` for the algorithm to decide how to schedule the all-gather and reduce-scatter operations, but for DeepSpeed this is transparent to the user.

<Tip>

For FSDP, simply set `fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP`. With the latest [`transformers`] versions, we try our best to figure out the suitable `fsdp_transformer_layer_cls_to_wrap` for HF transformers models. However, if you get an error regarding it, please specify this value explicitly.

</Tip>
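
For the curious, `TRANSFORMER_BASED_WRAP` resolves to something along the lines of PyTorch's `transformer_auto_wrap_policy`; a hand-rolled sketch (the Llama layer class is just an illustrative choice) would be:

```python
import functools

from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer  # example layer class

# Each decoder layer becomes its own FSDP unit, which is what determines where
# the all-gather / reduce-scatter operations are scheduled. This is roughly what
# `fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP` together with
# `fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer` expresses.
auto_wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)
```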

### Parameters Summoning

FSDP requires an explicit `--fsdp_use_orig_params` flag if using `torch.compile`; see [the PyTorch documentation](https://pytorch.org/docs/stable/fsdp.html#module-torch.distributed.fsdp). For DeepSpeed this is transparent to the user.

<Tip>

For FSDP, when using `torch.compile`, please set `fsdp_use_orig_params: True`.

</Tip>
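
In raw PyTorch terms (a sketch; `model` and the process group are assumed to be set up already), the combination looks like:

```python
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

# `use_orig_params=True` keeps the original (unflattened) parameters visible,
# which `torch.compile` and non-uniform `requires_grad` within a unit require.
model = FSDP(model, use_orig_params=True)
model = torch.compile(model)
```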


## Training

DeepSpeed requires explicit `--gradient_accumulation_steps` and `--gradient_clipping` flags. For FSDP this is transparent to the user.

<Tip>

When using DeepSpeed, set `gradient_accumulation_steps: "auto"` and `gradient_clipping: "auto"` to automatically pick up values set in the [`Accelerator`] or [`TrainingArguments`] (if using `transformers`).

</Tip>
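
A sketch of the corresponding fragment of a DeepSpeed config, passed here as a dict through `DeepSpeedPlugin.hf_ds_config` (the remaining keys are illustrative assumptions):

```python
from accelerate.utils import DeepSpeedPlugin

# "auto" defers the value to 🤗 Accelerate / transformers rather than
# hard-coding it in the DeepSpeed config itself.
ds_config = {
    "zero_optimization": {"stage": 3},
    "gradient_accumulation_steps": "auto",
    "gradient_clipping": "auto",
    "bf16": {"enabled": "auto"},
    "train_micro_batch_size_per_gpu": "auto",
}
deepspeed_plugin = DeepSpeedPlugin(hf_ds_config=ds_config)
```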


## On Differences in Data Precision Handling

To discuss how data precision is handled in both FSDP and DeepSpeed, it is instructive to first give an overview of how model parameters are handled in these frameworks. Before the model / optimizer parameters are distributed across GPUs, parameter preparation is involved to first "flatten" them to one-dimensional [`torch.Tensor`](https://pytorch.org/docs/stable/tensors.html#torch-tensor)s. The implementation of FSDP / DeepSpeed varies in the `dtype` in which these "flattened" parameters are stored, and there are ramifications with regard to how [`torch.Optimizer`](https://pytorch.org/docs/stable/optim.html#module-torch.optim)s allocate their `dtype`s. The table below outlines the processes for both frameworks; the "Local" column indicates the process occurring at a per-GPU level, so any memory overheads from upcasting should be understood to be amortized by the number of GPUs used.


<Tip>

As a rule of thumb, for stable training with automatic mixed precision, all the trainable parameters have to be in `torch.float32`.

</Tip>

Process | Local | Framework | Details
--|--|--|--
Loading, i.e., [`AutoModel.from_pretrained(..., torch_dtype=torch_dtype)`] |
Preparation, i.e., creation of "flat params" | ✅ | FSDP<br>DeepSpeed | created in `torch_dtype`.<br> disregards `torch_dtype`, created in `float32`.
Optimizer initialization | ✅ | FSDP<br>DeepSpeed | creates parameters in `torch_dtype`<br> creates parameters in `float32`
Training Step, i.e., forward, backward, reduction | | FSDP<br>DeepSpeed | follows [`MixedPrecision`](https://pytorch.org/docs/stable/fsdp.html#torch.distributed.fsdp.MixedPrecision)<br> follows `deepspeed_config_file` mixed precision settings.
Optimizer (Pre-Step) | ✅ | FSDP<br>DeepSpeed | upcasting (if any) to `torch_dtype`<br>upcasted to `float32`
Optimizer (Actual Step) | ✅ | FSDP<br>DeepSpeed | occurs in `torch_dtype` <br> occurs in `float32`.

<Tip warning={true}>

Therefore, when using DeepSpeed with a small number of GPUs, be aware of potentially significant memory overheads due to the upcasting during preparation.

</Tip>

<Tip>

With FSDP, in the absence of mixed precision, it is possible to operate the [`torch.Optimizer`](https://pytorch.org/docs/stable/optim.html#module-torch.optim) in low precision `torch_dtype`, which may be helpful when using a small number of GPUs.

</Tip>

<Tip warning={true}>

With mixed precision, FSDP and DeepSpeed will upcast in the model preparation step (cf. table above). But do note that FSDP will then save checkpoints in the upcasted precision; DeepSpeed may still save low precision checkpoints if `--zero3_save_16bit_model` is specified.

</Tip>


To clarify the above table, consider the concrete examples below; the optimizer pre-step and actual step are combined for brevity. With FSDP it is possible to operate in the two modes shown below, but DeepSpeed can only operate in one.

Framework | Model Loading (`torch_dtype`) | Mixed Precision | Preparation (Local) | Training | Optimizer (Local)
--|--|--|--|--|--
FSDP | bf16 | default (none) | bf16 | bf16 | bf16
FSDP | bf16 | bf16 | fp32 | bf16 | fp32
DeepSpeed | bf16 | bf16 | fp32 | bf16 | fp32
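
For reference, the two FSDP rows roughly correspond to the sketch below (`MixedPrecision` is the PyTorch policy object that 🤗 Accelerate builds from `mixed_precision: bf16`; the model name is a placeholder):

```python
import torch
from torch.distributed.fsdp import MixedPrecision
from transformers import AutoModelForCausalLM

# Row 1: load in bf16 and do not enable mixed precision -> flat params, training
# and optimizer state all stay in bf16 (lowest memory, potentially less stable).
model = AutoModelForCausalLM.from_pretrained("your-model-name", torch_dtype=torch.bfloat16)

# Row 2: load in bf16 *and* enable bf16 mixed precision -> the flat params are
# upcast to fp32 (as DeepSpeed does), while forward/backward still run in bf16.
bf16_policy = MixedPrecision(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.bfloat16,
    buffer_dtype=torch.bfloat16,
)
```
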
docs/source/package_reference/cli.md (4 additions, 0 deletions)
@@ -208,6 +208,10 @@ The following arguments are only useful when `use_fsdp` is passed or Fully Shard
* `--fsdp_transformer_layer_cls_to_wrap` (`str`) -- Transformer layer class name (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`, `T5Block` ...
* `--fsdp_backward_prefetch_policy` (`str`) -- FSDP's backward prefetch policy.
* `--fsdp_state_dict_type` (`str`) -- FSDP's state dict type.
* `--fsdp_forward_prefetch` (`str`) -- If True, FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass.
* `--fsdp_use_orig_params` (`str`) -- If True, allows non-uniform `requires_grad` mixed in an FSDP unit.
* `--fsdp_cpu_ram_efficient_loading` (`str`) -- If True, only the first process loads the pretrained model checkpoint while all other processes have empty weights. When using this, `--fsdp_sync_module_states` needs to be set to True.
* `--fsdp_sync_module_states` (`str`) -- If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0.

**Megatron-LM Arguments**:

src/accelerate/accelerator.py (67 additions, 0 deletions)
@@ -1447,6 +1447,73 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
),
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
)

# In the event the model had been loaded in low precision, but
# mixed precision had also been activated, then we follow DeepSpeed's
# strategy to hold the parameters in full precision.
# - assume that trainer.args.bf16 and trainer.args.fp16 are already checked against
#   fsdp_plugin.mixed_precision_policy.
# - NOTE: we do not check the mixed_precision attribute on the FSDP root wrapper.
#   * this attribute will always be set by init_utils.init_core_state so it is never None.
#   * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype
#   * if the model is loaded in 16bit, and even if mixed_precision.param_dtype is None,
#     we still want to upcast the flat_param.
if self.mixed_precision != "no":  # if mixed precision is set
    upcasted_log = []
    for module in FSDP.fsdp_modules(model):
        # Referencing DeepSpeed Zero3
        # - in Init, params are converted to 16bit while partitioning.
        # - in accelerator.prepare, deepspeed.initialize is called to:
        #   * create the DeepSpeedEngine.
        #   * since zero_optimization() is True, call engine._configure_zero_optimizer.
        #
        # Inside the DeepSpeed Zero3 optimizer configuration, which initializes
        # DeepSpeedZeroOptimizer_Stage3, during which:
        #   * trainable_param_groups are obtained from the attached optimizer
        #     (already partitioned in 16bit).
        #   * then _setup_for_real_optimizer -> _create_fp32_partitions
        #     which performs the fp32 upcasting.

        # To mimic DeepSpeed's casting in FSDP, we look at the (single) FlatParameter held
        # within an FSDP wrapper. This FlatParameter will be seen by the optimizer.
        # - even though there is a torch.device('meta') guard below, we
        #   expect _init_utils._init_param_handle_from_module to already
        #   sync the parameter.

        if not module._has_params:
            continue  # skip if FSDP module not managing parameters
        param = module._flat_param
        if (
            param.dtype != torch.float32
            and param.device != torch.device("meta")
            and param.requires_grad
        ):
            # keep a log of the named params that were upcasted
            # NOTE: resorted to this because warnings.simplefilter("once") is somehow not working
            name_param_log = (module.module.__class__.__name__, ", ".join(module._flat_param._fqns))
            if name_param_log not in upcasted_log:
                upcasted_log.append(name_param_log)

            # this works because of FSDP's _runtime_utils.lazy_init.
            # Have to be careful not to call anything before this that
            # triggers lazy_init (e.g., _is_fsdp_root).
            param.data = param.data.to(torch.float32)  # upcasting
            module._handle._orig_param_dtype = torch.float32  # update

Contributor, commenting on lines +1500 to +1501:

@fabianlim, it looks like this introduced a breakage in FSDP integration - one of our tests now fails because of the last 2 lines, but works fine prior to this PR being merged.

    output = self.model(**batch)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 843, in forward
    args, kwargs = _pre_forward(
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 380, in _pre_forward
    unshard_fn(state, handle)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 415, in _pre_forward_unshard
    _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_runtime_utils.py", line 288, in _unshard
    ran_pre_unshard = handle.pre_unshard()
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 1252, in pre_unshard
    ret = self._writeback_orig_params()
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2249, in _writeback_orig_params
    self._writeback_tensor(
  File "/env/lib/conda/mgor-core-dev/lib/python3.10/site-packages/torch/distributed/fsdp/_flat_param.py", line 2348, in _writeback_tensor
    raise RuntimeError(
RuntimeError: Cannot writeback when the parameter shape changes
Expects torch.Size([48000]) but got torch.Size([3000, 16])

the total size (numel) is correct, but the shape is wrong

Contributor:

also cc: @muellerzr

Contributor:

the training step was just a trivial tiny llama2 model, like: https://huggingface.co/stas/tiny-random-llama-2

and the config was:

distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_forward_prefetch: false  # can flip to true once requiring torch>=2.2.0
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true


Contributor (Author):

@stas00 can you share the accelerate config? This looks like something commonly seen if `NO_SHARD` is used.

Contributor:

full config:

compute_environment: LOCAL_MACHINE

distributed_type: FSDP
fsdp_config:
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_forward_prefetch: false  # can flip to true once requiring torch>=2.2.0
  fsdp_backward_prefetch_policy: BACKWARD_PRE
  fsdp_offload_params: false
  fsdp_sharding_strategy: 1
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_cpu_ram_efficient_loading: true
  fsdp_sync_module_states: true
  fsdp_use_orig_params: true

mixed_precision: bf16
machine_rank: 0
main_process_ip: null
main_process_port: null
main_training_function: main
num_machines: 1
num_processes: {num_processes}
use_cpu: false

Contributor (Author):

OK, I think it’s because of `use_orig_params` being set to true. This was a setting that I didn’t test.

Contributor:

Is there a reason to do the upcasting so late in the game? Won't it be much simpler to do it before the model gets wrapped by FSDP?

Accelerate knows by this time everything it needs to know to do the right thing, no?

@fabianlim (Contributor, Author), Jun 4, 2024:

The reason is because we only wanted to cast the sharded params and not the whole model. Casting the whole model is not a good idea. DeepSpeed does the same thing and only casts the sharded params

@stas00 (Contributor), Jun 4, 2024:

Didn't we say that this was a convenience workaround and that the model should have been loaded in fp32 in the first place for fsdp to converge well?

Frankly, when do you not shard most of the model parameters? And even if sharding is selective, typically it's the big params that are sharded, so 2x of that would be close to all of the model's params doubled in size.

@fabianlim (Contributor, Author), Jun 7, 2024:

@stas00 Essentially what I am saying is that upcasting ahead of time will, in general, incur a 2X overhead in CPU memory; in most general usage patterns that overhead will be there regardless of whether we use low_cpu_mem mode or not.

BTW I'm having a bit of trouble reproducing your error on torch 2.1.2 and 2.2.0.

  • I took the learning_rate_repro.py script in the main issue
  • I ran the following command, where dataset.json is prepared following this
  • accelerate_stas.yaml is the config that was shared above.
  • using the same stas/tiny-random-llama-2
accelerate launch \
    --num_processes 4 \
    --main_process_port $MASTER_PORT \
    --config_file accelerate_stas.yaml \
    learning_rate_repro.py  \
      --model_name stas/tiny-random-llama-2 \
      --max_seq_len 4096 \
      --num_train_epochs 1 \
      --output_dir './results' \
      --per_device_train_batch_size 8 \
      --gradient_accumulation_steps 1 \
      --include_tokens_per_second \
      --gradient_checkpointing \
      --lr_scheduler_type "linear" \
      --learning_rate 1e-6 \
      --logging_steps 1 \
	  --packing True \
       --dataset_for_packing dataset.json \
      --dataset_text_field output \
      --bf16 \
      --max_steps 100 \
      --save_strategy 'no'

and the training runs.

{'loss': 8.0075, 'grad_norm': 0.9171388149261475, 'learning_rate': 9.9e-06, 'epoch': 1.0}
{'loss': 7.9684, 'grad_norm': 1.1230570077896118, 'learning_rate': 9.800000000000001e-06, 'epoch': 2.0}
{'loss': 7.8988, 'grad_norm': 1.232378602027893, 'learning_rate': 9.7e-06, 'epoch': 3.0}
{'loss': 7.8119, 'grad_norm': 1.1301443576812744, 'learning_rate': 9.600000000000001e-06, 'epoch': 4.0}
{'loss': 7.7319, 'grad_norm': 1.0580438375473022, 'learning_rate': 9.5e-06, 'epoch': 5.0}


    # report the warnings
    # some messages can be quite repetitive, especially when reporting about layers that have identical architecture.
    if self.is_main_process:
        for name_log, param_log in upcasted_log:
            warnings.warn(
                f"Upcasted low precision parameters in {name_log} because mixed precision turned on in FSDP. "
                f"Affects: {param_log}."
            )

        if len(upcasted_log) > 0:
            warnings.warn(
                "FSDP upcast of low precision parameters may affect the precision of model checkpoints."
            )

# if the previous and current models are same, delete the previous one
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
    del self._models[-2]
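
For readers following along, a minimal sketch to observe the upcasting added in this diff (run under `accelerate launch` with an FSDP config and bf16 mixed precision; the model name is taken from the repro above):

```python
import torch
from accelerate import Accelerator
from transformers import AutoModelForCausalLM

accelerator = Accelerator(mixed_precision="bf16")

# Deliberately load in low precision; with bf16 mixed precision enabled, the FSDP
# flat parameters should be upcast to fp32 during prepare(), with a warning emitted.
model = AutoModelForCausalLM.from_pretrained("stas/tiny-random-llama-2", torch_dtype=torch.bfloat16)
model = accelerator.prepare(model)

if accelerator.is_main_process:
    print({p.dtype for p in model.parameters()})  # expected: {torch.float32}
```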