update document to discuss fsdp and ds plugins. minor fixes.
fabianlim committed Apr 23, 2024
1 parent 0b8e97c commit 681c697
Showing 2 changed files with 15 additions and 2 deletions.
10 changes: 9 additions & 1 deletion docs/source/concept_guides/fsdp_and_deepspeed.md
@@ -13,7 +13,7 @@ specific language governing permissions and limitations under the License.
rendered properly in your Markdown viewer.
-->

# Moving between FSDP And DeepSpeed (DRAFT)
# Moving between FSDP And DeepSpeed

🤗 Accelerate offers flexibility of training frameworks by integrating two extremely powerful tools for distributed training, namely [PyTorch FSDP](../usage_guides/fsdp.md) and [Microsoft DeepSpeed](../usage_guides/deepspeed.md). The aim of this tutorial is to draw parallels, as well as to outline potential differences, to empower the user to switch seamlessly between these two frameworks.

@@ -53,6 +53,14 @@ For detailed descriptions of the above, refer to [🤗 `Accelerate` launch docum

To access other DeepSpeed configurations, such as mixed precision settings,
you need to pass in a `--deepspeed_config_file`, see the [documentation](../usage_guides/deepspeed#deepspeed-config-file).

DeepSpeed can also be configured via [`DeepSpeedPlugin`], e.g., `DeepSpeedPlugin.zero_stage` is equivalent to `--zero_stage`, and `DeepSpeedPlugin.hf_ds_config` can be used to pass a `--deepspeed_config_file`, as sketched below this tip.

</Tip>
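
For illustration, a minimal sketch of the `DeepSpeedPlugin` route described above, assuming `deepspeed` is installed and the script is started with `accelerate launch` (the ZeRO stage and gradient accumulation values are arbitrary):

```python
from accelerate import Accelerator, DeepSpeedPlugin

# Equivalent to passing --zero_stage 3 on the command line; a path to a full
# DeepSpeed JSON config could be supplied via hf_ds_config instead.
deepspeed_plugin = DeepSpeedPlugin(zero_stage=3, gradient_accumulation_steps=1)
accelerator = Accelerator(deepspeed_plugin=deepspeed_plugin)
```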

<Tip>

FSDP can also be configured via [`FullyShardedDataParallelPlugin`], e.g., `FullyShardedDataParallelPlugin.sharding_strategy` is equivalent to `--fsdp_sharding_strategy`, as sketched below this tip.

</Tip>
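
Similarly, a minimal sketch of the `FullyShardedDataParallelPlugin` route, assuming a distributed run launched with `accelerate launch` (the sharding strategy value is arbitrary):

```python
from torch.distributed.fsdp import ShardingStrategy

from accelerate import Accelerator, FullyShardedDataParallelPlugin

# Equivalent to passing --fsdp_sharding_strategy FULL_SHARD on the command line.
fsdp_plugin = FullyShardedDataParallelPlugin(sharding_strategy=ShardingStrategy.FULL_SHARD)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)
```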

7 changes: 6 additions & 1 deletion src/accelerate/accelerator.py
@@ -1490,7 +1490,7 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
):
# keep a log of names_params that were upcasted
# NOTE: resorted to this because warnings.simplefilter("once") is somehow not working
name_param_log = (module.module.__class__.__name__, ",".join(module._flat_param._fqns))
name_param_log = (module.module.__class__.__name__, ", ".join(module._flat_param._fqns))
if name_param_log not in upcasted_log:
upcasted_log.append(name_param_log)

@@ -1509,6 +1509,11 @@ def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, e
f"Affects: {param_log}."
)

if len(upcasted_log) > 0:
warnings.warn(
"FSDP upcast of low precision paramters may affect precision of model checkpoints."
)

# if the previous and current models are same, delete the previous one
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
del self._models[-2]
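
The warning added above fires when `prepare` upcasts low precision FSDP parameters under mixed precision. A hypothetical sketch of a setup that would trigger it (the model name and launch configuration are illustrative only, assuming an FSDP config is passed via `accelerate launch`):

```python
import torch
from transformers import AutoModelForCausalLM

from accelerate import Accelerator

# Mixed precision plus a model loaded in a low precision dtype leads FSDP to
# upcast the flat parameters to fp32, which now emits the warning above.
accelerator = Accelerator(mixed_precision="bf16")  # FSDP enabled via the launch config
model = AutoModelForCausalLM.from_pretrained("gpt2", torch_dtype=torch.bfloat16)
model = accelerator.prepare(model)
```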
