diff --git a/mmengine/model/wrappers/fully_sharded_distributed.py b/mmengine/model/wrappers/fully_sharded_distributed.py index 1d05ecb947..87780b3bfe 100644 --- a/mmengine/model/wrappers/fully_sharded_distributed.py +++ b/mmengine/model/wrappers/fully_sharded_distributed.py @@ -12,7 +12,7 @@ from mmengine.structures import BaseDataElement # support customize fsdp policy -FSDP_WRAP_POLICYS = Registry('fsdp wrap policy') +FSDP_WRAP_POLICIES = Registry('fsdp wrap policy') @MODEL_WRAPPERS.register_module() @@ -60,7 +60,7 @@ class MMFullyShardedDataParallel(FullyShardedDataParallel): users' pre-defined config in MMEngine, its type is expected to be `None`, `str` or `Callable`. If it's `str`, then MMFullyShardedDataParallel will try to get specified method in - ``FSDP_WRAP_POLICYS`` registry,and this method will be passed to + ``FSDP_WRAP_POLICIES`` registry, and this method will be passed to FullyShardedDataParallel to finally initialize model. Note that this policy currently will only apply to child modules of @@ -122,10 +122,10 @@ def __init__( if fsdp_auto_wrap_policy is not None: if isinstance(fsdp_auto_wrap_policy, str): - assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICYS, \ - '`FSDP_WRAP_POLICYS` has no ' \ + assert fsdp_auto_wrap_policy in FSDP_WRAP_POLICIES, \ + '`FSDP_WRAP_POLICIES` has no ' \ f'function {fsdp_auto_wrap_policy}' - fsdp_auto_wrap_policy = FSDP_WRAP_POLICYS.get( # type: ignore + fsdp_auto_wrap_policy = FSDP_WRAP_POLICIES.get( # type: ignore fsdp_auto_wrap_policy) if not isinstance(fsdp_auto_wrap_policy, Callable): # type: ignore