Add support for colossalai 0.1.11 #15888

Merged · 18 commits · Dec 20, 2022
63 changes: 47 additions & 16 deletions src/pytorch_lightning/strategies/colossalai.py
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
+import operator
 from typing import Any, Callable, Dict, List, Mapping, Optional, TYPE_CHECKING, Union
 
 import torch
-from lightning_utilities.core.imports import RequirementCache
+from lightning_utilities.core.imports import compare_version, RequirementCache
 from lightning_utilities.core.rank_zero import rank_zero_warn
 from torch import Tensor
 from torch.nn import Module
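For readers unfamiliar with `compare_version`: it is the helper that lets the strategy branch on the installed ColossalAI version. A minimal, self-contained sketch of the same check (the call signature matches the one added in this diff; the print statement is purely illustrative):

```python
import operator

from lightning_utilities.core.imports import compare_version

# True when the installed colossalai package is >= 0.1.10;
# in lightning_utilities, a missing package compares as False.
new_version_flag = compare_version("colossalai", operator.ge, "0.1.10")
print("new GeminiDDP path" if new_version_flag else "legacy ZeroDDP path")
```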
@@ -130,7 +132,7 @@ def __init__(
         force_outputs_fp32: bool = False,
         gpu_margin_mem_ratio: float = 0.0,
         chunk_search_range: int = 64 * 1024**2,
-        chunk_search_n_grids: int = 1024,
+        chunk_search_n_grids: int = 4096,
         min_chunk_size: Optional[int] = None,
         initial_scale: float = 2**16,
         min_scale: float = 1,
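The default `chunk_search_n_grids` grows from 1024 to 4096, which shrinks the step size of the chunk-size search. A quick arithmetic sketch, using the 64 MiB default `chunk_search_range` above and the `search_interval = ceil(range / n_grids)` formula that appears later in this diff:

```python
import math

chunk_search_range = 64 * 1024**2  # default search window: 64 MiB
search_interval_old = math.ceil(chunk_search_range / 1024)  # old default
search_interval_new = math.ceil(chunk_search_range / 4096)  # new default
assert search_interval_old == 64 * 1024  # 64 KiB granularity
assert search_interval_new == 16 * 1024  # 16 KiB granularity
```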
@@ -264,23 +266,52 @@ def setup_precision_plugin(self) -> None:
         )
         assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase))
         pl_module = self.model
-        process_group = ProcessGroup()
+
+        new_version_flag = compare_version("colossalai", operator.ge, "0.1.10")
+
         if not hasattr(pl_module, "_colossalai_zero"):
-            if self.use_chunk:
-                chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
-                    self.model, **self.chunk_size_search_kwargs
-                )
-            else:
-                chunk_size = None
-            chunk_manager = ChunkManager(
-                chunk_size,
-                process_group,
-                self.enable_distributed_storage,
-                GeminiManager.get_default_device(self.placement_policy),
-            )
-            gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
-            model = _LightningModuleWrapperBase(self.model)
-            self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
+            if not new_version_flag:
+                if self.use_chunk:
+                    chunk_size = self.chunk_size or ChunkManager.search_chunk_size(
+                        self.model, **self.chunk_size_search_kwargs
+                    )
+                else:
+                    chunk_size = None
+                process_group = ProcessGroup()
+                chunk_manager = ChunkManager(
+                    chunk_size,
+                    process_group,
+                    self.enable_distributed_storage,
+                    GeminiManager.get_default_device(self.placement_policy),
+                )
+                gemini_manager = GeminiManager(self.placement_policy, chunk_manager)
+                model = _LightningModuleWrapperBase(self.model)
+                self.model = ZeroDDP(model, gemini_manager, self.force_outputs_fp32)
+            else:
+                with _patch_cuda_is_available():
+                    from colossalai.nn.parallel import GeminiDDP
+                    from colossalai.utils import get_current_device
+
+                assert self.use_chunk, "`ColossalAIStrategy` must use chunk in versions higher than 0.1.10"
+                chunk_search_range = self.chunk_size_search_kwargs["search_range"]
+                search_range_mb = chunk_search_range / 1024**2
+                search_interval = math.ceil(chunk_search_range / self.chunk_size_search_kwargs["n_grids"])
+                min_chunk_size_mb = self.chunk_size_search_kwargs["min_chunk_size"]
+                if min_chunk_size_mb:
+                    min_chunk_size_mb /= 1024**2
+
+                model = _LightningModuleWrapperBase(self.model)
+                self.model = GeminiDDP(
+                    module=model,
+                    device=get_current_device(),
+                    placement_policy=self.placement_policy,
+                    pin_memory=True,
+                    force_outputs_fp32=self.force_outputs_fp32,
+                    search_range_mb=search_range_mb,
+                    hidden_dim=search_interval,
+                    min_chunk_size_mb=min_chunk_size_mb,
+                )
+
             assert self.model is not None
             pl_module._colossalai_zero = [self.model]  # type: ignore[assignment]
         else:
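For context, a hedged usage sketch of the strategy after this PR. The constructor arguments mirror the defaults shown above; the `Trainer` settings (`accelerator`, `devices`, `precision`) are illustrative assumptions, not something this diff prescribes:

```python
import pytorch_lightning as pl
from pytorch_lightning.strategies import ColossalAIStrategy

# With colossalai >= 0.1.10 the strategy wraps the module in GeminiDDP,
# translating chunk_search_range / chunk_search_n_grids / min_chunk_size
# into GeminiDDP's search_range_mb / hidden_dim / min_chunk_size_mb.
strategy = ColossalAIStrategy(
    use_chunk=True,                   # required on the new code path
    chunk_search_range=64 * 1024**2,  # bytes; converted to MiB internally
    chunk_search_n_grids=4096,        # new default from this PR
)
trainer = pl.Trainer(accelerator="gpu", devices=2, precision=16, strategy=strategy)
```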