Add support for colossalai 0.1.11 #15888

Merged
merged 18 commits on Dec 20, 2022
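For orientation, a minimal sketch of how the strategy touched by this PR is enabled from user code. The `LitModel` class and the Trainer arguments are illustrative assumptions, not part of this diff; the strategy is intended to run with 16-bit precision on CUDA devices and relies on ColossalAI's own optimizers.

```python
# Illustrative usage only (not part of this diff): a toy LightningModule driven
# by the ColossalAI strategy.  Assumes colossalai is installed and a CUDA GPU
# is available.
import torch
import pytorch_lightning as pl
from pytorch_lightning.strategies import ColossalAIStrategy


class LitModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        # the strategy expects a ColossalAI optimizer such as HybridAdam
        from colossalai.nn.optimizer import HybridAdam

        return HybridAdam(self.parameters(), lr=1e-3)


trainer = pl.Trainer(
    accelerator="cuda",
    devices=1,
    precision=16,                   # the strategy runs with fp16 mixed precision
    strategy=ColossalAIStrategy(),  # defaults match the signature shown below
    max_epochs=1,
)
```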
src/pytorch_lightning/strategies/colossalai.py (77 changes: 60 additions & 17 deletions)
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import operator
from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Union

import torch
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.imports import compare_version, RequirementCache
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.nn import Module
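The newly imported `compare_version` helper drives the version branch further down in this file. A small sketch of how it is used, assuming colossalai is importable locally (it is expected to report False when the package cannot be imported):

```python
# Minimal sketch of the version gate introduced in this file.  The printed value
# depends on the locally installed colossalai; 0.1.11 or newer should yield True.
import operator

from lightning_utilities.core.imports import compare_version

# True when the installed colossalai is strictly newer than 0.1.10, i.e. the
# GeminiDDP code path added by this PR should be taken.
new_version_flag = compare_version("colossalai", operator.gt, "0.1.10")
print(new_version_flag)
```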
@@ -130,7 +132,7 @@ def __init__(
force_outputs_fp32: bool = False,
gpu_margin_mem_ratio: float = 0.0,
chunk_search_range: int = 64 * 1024**2,
chunk_search_n_grids: int = 1024,
chunk_search_n_grids: int = 4096,
min_chunk_size: Optional[int] = None,
initial_scale: float = 2**16,
min_scale: float = 1,
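The only behavioural change in this hunk is the finer default for `chunk_search_n_grids`. A quick illustration of what that means for the chunk-size search granularity (the arithmetic mirrors the `search_interval` computation added later in this file; it is not an excerpt from the diff):

```python
# With the same 64 MiB default search range, raising n_grids from 1024 to 4096
# makes the chunk-size search step four times finer.
import math

search_range = 64 * 1024**2  # bytes, the default chunk_search_range
for n_grids in (1024, 4096):
    step = math.ceil(search_range / n_grids)
    print(f"n_grids={n_grids}: search step = {step} bytes ({step // 1024} KiB)")
# n_grids=1024: search step = 65536 bytes (64 KiB)
# n_grids=4096: search step = 16384 bytes (16 KiB)
```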
@@ -237,7 +239,8 @@ def _post_init_method(self, module: torch.nn.Module, *args: Any, **kwargs: Any)
if getattr(module, "_colossalai_module", False) is True:
return
super()._post_init_method(module, *args, **kwargs)
module._colossalai_module = True # type: ignore[assignment]
for sub_module in module.modules():
sub_module._colossalai_module = True # type: ignore[assignment]

return ModelShardedContext()
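Flagging every submodule, rather than only the top-level module, appears to be what lets the early-return at the top of `_post_init_method` skip nested modules that were already processed. A small stand-alone illustration of the traversal:

```python
# Module.modules() yields the module itself plus all nested submodules, so the
# flag set in the loop is visible to the early-return check whenever any of
# those submodules is encountered again.
import torch

net = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
for sub_module in net.modules():
    sub_module._colossalai_module = True

print(all(getattr(m, "_colossalai_module", False) for m in net.modules()))  # True
```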

@@ -264,23 +267,54 @@ def setup_precision_plugin(self) -> None:
)
assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
pl_module = self.model
process_group = ProcessGroup()

new_version_flag = compare_version("colossalai", operator.gt, "0.1.10")

if not hasattr(pl_module, "_colossalai_zero"):
if self.use_chunk:
chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
self.model, **self.chunk_size_search_kwargs
if not new_version_flag:
if self.use_chunk:
chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
self.model, **self.chunk_size_search_kwargs
)
else:
chunk_size = None
process_group = ProcessGroup()
chunk_manager = ChunkManager(
chunk_size,
process_group,
self.enable_distributed_storage,
GeminiManager.get_default_device(self.placement_policy),
)
gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
model = _LightningModuleWrapperBase(self.model)
self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
else:
chunk_size = None
chunk_manager = ChunkManager(
chunk_size,
process_group,
self.enable_distributed_storage,
GeminiManager.get_default_device(self.placement_policy),
)
gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
model = _LightningModuleWrapperBase(self.model)
self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
with _patch_cuda_is_available():
from colossalai.nn.parallel import GeminiDDP
from colossalai.utils import get_current_device
if not self.use_chunk:
raise MisconfigurationException(
"`ColossalAIStrategy` must use chunk in versions higher than 0.1.10"
)
chunk_search_range = self.chunk_size_search_kwargs["search_range"]
search_range_mb = self.chunk_size_search_kwargs["search_range"] / 1024**2
search_interval = math.ceil(chunk_search_range / self.chunk_size_search_kwargs["n_grids"])
min_chunk_size_mb = self.chunk_size_search_kwargs["min_chunk_size"]
if min_chunk_size_mb:
min_chunk_size_mb /= 1024**2

model = _LightningModuleWrapperBase(self.model)
self.model = GeminiDDP(
module=model,
device=get_current_device(),
placement_policy=self.placement_policy,
pin_memory=True,
force_outputs_fp32=self.force_outputs_fp32,
search_range_mb=search_range_mb,
hidden_dim=search_interval,
min_chunk_size_mb=min_chunk_size_mb,
)

assert self.model is not None
pl_module._colossalai_zero = [self.model] # type: ignore[assignment]
else:
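In the new branch, the strategy's constructor arguments stay in bytes while `GeminiDDP` (colossalai > 0.1.10) takes `search_range_mb` and `min_chunk_size_mb` in megabytes, hence the divisions by `1024**2` above. A sketch of the conversions with this file's defaults (`min_chunk_size` defaults to None; a value is assumed here purely for illustration):

```python
# Unit handling behind the GeminiDDP call above, using this file's defaults.
import math

chunk_size_search_kwargs = {
    "search_range": 64 * 1024**2,    # bytes
    "n_grids": 4096,
    "min_chunk_size": 32 * 1024**2,  # bytes (assumed; the default is None)
}

search_range_mb = chunk_size_search_kwargs["search_range"] / 1024**2  # 64.0 MB
search_interval = math.ceil(
    chunk_size_search_kwargs["search_range"] / chunk_size_search_kwargs["n_grids"]
)  # 16384 bytes between candidate chunk sizes
min_chunk_size_mb = chunk_size_search_kwargs["min_chunk_size"] / 1024**2  # 32.0 MB

print(search_range_mb, search_interval, min_chunk_size_mb)
```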
@@ -329,10 +363,19 @@ def setup(self, trainer: "pl.Trainer") -> None:
self.accelerator.setup(trainer)
assert self.lightning_module is not None
self.lightning_module._device = self.root_device
self.ignore_no_grad_parameters(self.root_device)
self.setup_optimizers(trainer)
self.setup_precision_plugin()
self.model_to_device()

def ignore_no_grad_parameters(self, running_device) -> None:
# for those parameters with no gradients
# we should ignore them in DDP and move them to CUDA
for param in self.model.parameters():
if not param.requires_grad:
param._ddp_to_ignore = True
param.data = param.data.to(running_device)
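The `_ddp_to_ignore` flag is read on the ColossalAI side; the assumption here is that its DDP/Gemini wrappers exclude flagged parameters from chunk management, which is why frozen parameters must be moved to the running device by hand. Note that `setup` now calls this helper before `setup_optimizers` and `setup_precision_plugin`, so the exclusions are in place before the model is wrapped. A small sketch of the effect on a partially frozen model:

```python
# Sketch only: assumes ColossalAI's DDP wrappers skip parameters flagged with
# `_ddp_to_ignore`; that behaviour is inferred, not shown in this diff.
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
model[0].weight.requires_grad_(False)  # freeze part of the model
model[0].bias.requires_grad_(False)

running_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for param in model.parameters():
    if not param.requires_grad:
        param._ddp_to_ignore = True                 # excluded from chunk management
        param.data = param.data.to(running_device)  # but still placed on the device
```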

def model_to_device(self) -> None:
assert self.lightning_module is not None
pl_module = self.lightning_module