Add ColossalAI strategy (#14224)
Co-authored-by: HELSON <[email protected]>
Co-authored-by: rohitgr7 <[email protected]>
Co-authored-by: otaj <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
5 people authored Oct 11, 2022
1 parent 6f16e46 commit 2fef6d9
Showing 16 changed files with 933 additions and 8 deletions.
6 changes: 6 additions & 0 deletions .azure/gpu-tests.yml
@@ -97,6 +97,7 @@ jobs:
set -e
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'horovod' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'bagua' not in line] ; open(fname, 'w').writelines(lines)"
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)"
PYTORCH_VERSION=$(python -c "import torch; print(torch.__version__.split('+')[0])")
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/base.txt ${PYTORCH_VERSION}
@@ -110,6 +111,11 @@ jobs:
CUDA_VERSION_BAGUA=$(python -c "print([ver for ver in [116,113,111,102] if $CUDA_VERSION_MM >= ver][0])")
pip install "bagua-cuda$CUDA_VERSION_BAGUA"
PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])")
CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))")
CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])")
pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org
pip list
env:
PACKAGE_NAME: pytorch
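The new install lines derive the ColossalAI wheel tag from the local torch and CUDA versions. A standalone sketch of that selection logic (the versions in the comments are illustrative, not pinned by this commit):

```python
# Sketch of the wheel-tag selection performed by the CI lines above.
# Assumes a CUDA build of torch (torch.version.cuda is None on CPU builds).
import torch

# e.g. "1.12.1+cu113" -> "1.12"
pytorch_version = torch.__version__.split("+")[0][:4]
# CUDA version torch was built with, e.g. "11.3"
cuda_version = float(torch.version.cuda)
# pick the newest prebuilt CUDA variant the local toolkit can satisfy
cuda_build = [ver for ver in [11.3, 11.1] if cuda_version >= ver][0]

wheel = f"colossalai==0.1.10+torch{pytorch_version}cu{cuda_build}"
print(wheel)  # e.g. colossalai==0.1.10+torch1.12cu11.3
```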
36 changes: 30 additions & 6 deletions dockers/base-cuda/Dockerfile
@@ -54,9 +54,10 @@ RUN \
libopenmpi-dev \
openmpi-bin \
ssh \
ninja-build \
libnccl2=$TO_INSTALL_NCCL \
libnccl-dev=$TO_INSTALL_NCCL && \
# Install python
# Install python
add-apt-repository ppa:deadsnakes/ppa && \
apt-get install -y \
python${PYTHON_VERSION} \
@@ -65,7 +66,7 @@ RUN \
&& \
update-alternatives --install /usr/bin/python${PYTHON_VERSION%%.*} python${PYTHON_VERSION%%.*} /usr/bin/python${PYTHON_VERSION} 1 && \
update-alternatives --install /usr/bin/python python /usr/bin/python${PYTHON_VERSION} 1 && \
# Cleaning
# Cleaning
apt-get autoremove -y && \
apt-get clean && \
rm -rf /root/.cache && \
@@ -82,14 +83,15 @@ RUN \
rm get-pip.py && \
pip install -q fire && \
# Disable cache \
CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
export CUDA_VERSION_MM=$(python -c "print(''.join('$CUDA_VERSION'.split('.')[:2]))") && \
pip config set global.cache-dir false && \
# set particular PyTorch version
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/base.txt ${PYTORCH_VERSION} && \
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/extra.txt ${PYTORCH_VERSION} && \
python ./requirements/pytorch/adjust-versions.py requirements/pytorch/examples.txt ${PYTORCH_VERSION} && \
# Install all requirements \
pip install -r requirements/pytorch/devel.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \

# Install base requirements \
pip install -r requirements/pytorch/base.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html && \
rm assistant.py

ENV \
@@ -108,7 +110,7 @@ RUN \
export HOROVOD_BUILD_CUDA_CC_LIST=${HOROVOD_BUILD_CUDA_CC_LIST//"."/""} && \
echo $HOROVOD_BUILD_CUDA_CC_LIST && \
cmake --version && \
pip install --no-cache-dir -r ./requirements/pytorch/strategies.txt && \
pip install --no-cache-dir horovod && \
horovodrun --check-build

RUN \
@@ -136,6 +138,28 @@ RUN \
if [[ "$CUDA_VERSION_MM" = "$CUDA_VERSION_BAGUA" ]]; then python -c "import bagua_core; bagua_core.install_deps()"; fi && \
python -c "import bagua; print(bagua.__version__)"

RUN \
# install ColossalAI
SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
if [[ "$SHOULD_INSTALL_COLOSSAL" = "1" ]]; then \
PYTORCH_VERSION_COLOSSALAI=$(python -c "import torch; print(torch.__version__.split('+')[0][:4])") ; \
CUDA_VERSION_MM_COLOSSALAI=$(python -c "import torch ; print(''.join(map(str, torch.version.cuda)))") ; \
CUDA_VERSION_COLOSSALAI=$(python -c "print([ver for ver in [11.3, 11.1] if $CUDA_VERSION_MM_COLOSSALAI >= ver][0])") ; \
pip install "colossalai==0.1.10+torch${PYTORCH_VERSION_COLOSSALAI}cu${CUDA_VERSION_COLOSSALAI}" --find-links https://release.colossalai.org ; \
python -c "import colossalai; print(colossalai.__version__)" ; \
fi

RUN \
# install rest of strategies
# remove colossalai from requirements since it is installed separately
SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
if [[ "$SHOULD_INSTALL_COLOSSAL" = "0" ]]; then \
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \
fi && \
echo "$SHOULD_INSTALL_COLOSSAL" && \
cat requirements/pytorch/strategies.txt && \
pip install -r requirements/pytorch/devel.txt -r requirements/pytorch/strategies.txt --no-cache-dir --find-links https://download.pytorch.org/whl/cu${CUDA_VERSION_MM}/torch_stable.html

COPY requirements/pytorch/check-avail-extras.py check-avail-extras.py
COPY requirements/pytorch/check-avail-strategies.py check-avail-strategies.py

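The `python -c` one-liner that drops the `colossalai` pin recurs in the CI config and both Dockerfiles. Expanded for readability, the equivalent logic is:

```python
# Equivalent, expanded form of the repeated requirements-filtering one-liner.
fname = "requirements/pytorch/strategies.txt"
with open(fname) as fh:
    lines = [line for line in fh if "colossalai" not in line]
with open(fname, "w") as fh:
    fh.writelines(lines)
```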
6 changes: 5 additions & 1 deletion dockers/release/Dockerfile
@@ -41,7 +41,11 @@ RUN \
fi && \
# otherwise there is a collision between the folder name and the pkg name on PyPI
cd lightning && \
pip install .["extra","loggers","strategies"] --no-cache-dir && \
SHOULD_INSTALL_COLOSSAL=$(python -c "import torch; print(1 if int(torch.__version__.split('.')[1]) > 9 else 0)") && \
if [[ "$SHOULD_INSTALL_COLOSSAL" = "0" ]]; then \
python -c "fname = 'requirements/pytorch/strategies.txt' ; lines = [line for line in open(fname).readlines() if 'colossalai' not in line] ; open(fname, 'w').writelines(lines)" ; \
fi && \
pip install .["extra","loggers","strategies"] --no-cache-dir --find-links https://release.colossalai.org && \
cd .. && \
rm -rf lightning

4 changes: 3 additions & 1 deletion docs/source-pytorch/api_references.rst
@@ -185,6 +185,7 @@ precision
:template: classtemplate.rst

ApexMixedPrecisionPlugin
ColossalAIPrecisionPlugin
DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
@@ -285,7 +286,7 @@ strategies
:template: classtemplate.rst

BaguaStrategy
HivemindStrategy
ColossalAIStrategy
DDPFullyShardedNativeStrategy
DDPFullyShardedStrategy
DDPShardedStrategy
@@ -294,6 +295,7 @@ strategies
DDPStrategy
DataParallelStrategy
DeepSpeedStrategy
HivemindStrategy
HorovodStrategy
HPUParallelStrategy
IPUStrategy
1 change: 1 addition & 0 deletions docs/source-pytorch/extensions/plugins.rst
@@ -53,6 +53,7 @@ The full list of built-in precision plugins is listed below.
:template: classtemplate.rst

ApexMixedPrecisionPlugin
ColossalAIPrecisionPlugin
DeepSpeedPrecisionPlugin
DoublePrecisionPlugin
FullyShardedNativeMixedPrecisionPlugin
3 changes: 3 additions & 0 deletions docs/source-pytorch/extensions/strategy.rst
@@ -75,6 +75,9 @@ The below table lists all relevant strategies available in Lightning with their
* - collaborative
- :class:`~pytorch_lightning.strategies.HivemindStrategy`
- Strategy for training collaboratively on local machines or unreliable GPUs across the internet. :ref:`Learn more. <strategies/hivemind:Training on unreliable mixed GPUs across the internet>`
* - colossalai
- :class:`~pytorch_lightning.strategies.ColossalAIStrategy`
- Colossal-AI provides a collection of parallel components, aiming to let you write distributed deep learning models the same way you write a model for your laptop. `Learn more. <https://www.colossalai.org/>`__
* - fsdp_native
- :class:`~pytorch_lightning.strategies.DDPFullyShardedNativeStrategy`
- Strategy for Fully Sharded Data Parallel provided by PyTorch. :ref:`Learn more. <advanced/model_parallel:PyTorch Fully Sharded Training>`
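A minimal usage sketch for the new strategy (`MyLightningModule` is a hypothetical placeholder; the `"colossalai"` registry name and the `precision=16` requirement follow from this commit's strategy and precision plugin):

```python
import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=1,
    strategy="colossalai",  # strategy name registered by this commit
    precision=16,           # ColossalAIPrecisionPlugin only accepts 16
)
trainer.fit(MyLightningModule())  # MyLightningModule is a placeholder LightningModule
```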
1 change: 1 addition & 0 deletions requirements/pytorch/strategies.txt
@@ -1,6 +1,7 @@
# NOTE: the upper bound for the package version is only set for CI stability, and it is dropped while installing this package
# in case you want to preserve/enforce restrictions on the latest compatible version, add "strict" as an in-line comment

colossalai>=0.1.10
fairscale>=0.4.5, <=0.4.6
deepspeed>=0.6.0, <=0.7.0
# no need to install with [pytorch] as pytorch is already installed
2 changes: 2 additions & 0 deletions src/pytorch_lightning/plugins/__init__.py
@@ -5,6 +5,7 @@
from pytorch_lightning.plugins.io.hpu_plugin import HPUCheckpointIO
from pytorch_lightning.plugins.layer_sync import LayerSync, NativeSyncBatchNorm
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
@@ -27,6 +28,7 @@
"XLACheckpointIO",
"HPUCheckpointIO",
"ApexMixedPrecisionPlugin",
"ColossalAIPrecisionPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"IPUPrecisionPlugin",
2 changes: 2 additions & 0 deletions src/pytorch_lightning/plugins/precision/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pytorch_lightning.plugins.precision.apex_amp import ApexMixedPrecisionPlugin
from pytorch_lightning.plugins.precision.colossalai import ColossalAIPrecisionPlugin
from pytorch_lightning.plugins.precision.deepspeed import DeepSpeedPrecisionPlugin
from pytorch_lightning.plugins.precision.double import DoublePrecisionPlugin
from pytorch_lightning.plugins.precision.fsdp_native_native_amp import FullyShardedNativeNativeMixedPrecisionPlugin
@@ -26,6 +27,7 @@

__all__ = [
"ApexMixedPrecisionPlugin",
"ColossalAIPrecisionPlugin",
"DeepSpeedPrecisionPlugin",
"DoublePrecisionPlugin",
"FullyShardedNativeNativeMixedPrecisionPlugin",
90 changes: 90 additions & 0 deletions src/pytorch_lightning/plugins/precision/colossalai.py
@@ -0,0 +1,90 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Callable, Optional, Union

from lightning_utilities.core.rank_zero import WarningCache
from torch import Tensor
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.types import Steppable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities.enums import PrecisionType

warning_cache = WarningCache()


class ColossalAIPrecisionPlugin(PrecisionPlugin):
"""Precision plugin for ColossalAI integration.
Args:
precision: Half precision (16).
Raises:
ValueError:
If precison is not 16.
"""

def __init__(self, precision: Union[str, int] = 16) -> None:
if not (precision == PrecisionType.HALF):
raise ValueError(
f"`Trainer(strategy='colossalai', precision={precision!r})` is not supported."
" Consider setting `precision=16`."
)
super().__init__()
self.precision = precision

def backward( # type: ignore[override]
self,
tensor: Tensor,
model: "pl.LightningModule",
optimizer: Optional[Steppable],
optimizer_idx: Optional[int],
*args: Any,
**kwargs: Any,
) -> None:
assert optimizer is not None
optimizer.backward(tensor)

def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
optimizer.clip_grad_norm(None, clip_val)

def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
raise NotImplementedError("`clip_grad_by_value` is not supported by `ColossalAI`")

def optimizer_step( # type: ignore[override]
self,
optimizer: Steppable,
model: "pl.LightningModule",
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> Any:
closure_result = closure()
self._after_closure(model, optimizer, optimizer_idx)
skipped_backward = closure_result is None
if isinstance(model, pl.LightningModule) and model.automatic_optimization and skipped_backward:
raise ValueError(
"Skipping backward by returning `None` from your `training_step` is not supported by `ColossalAI`."
)
optimizer.step()

def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
return
# the gradients are not available in the model due to gradient partitioning in zero stage >= 2
warning_cache.warn(
f"You set `Trainer(track_grad_norm={trainer.track_grad_norm!r})' but this is not supported for ColossalAI."
" The setting will be ignored."
)
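A behavior sketch for the plugin above: it accepts only half precision and raises for anything else (example code, not part of the commit):

```python
from pytorch_lightning.plugins import ColossalAIPrecisionPlugin

plugin = ColossalAIPrecisionPlugin(precision=16)  # accepted

try:
    ColossalAIPrecisionPlugin(precision=32)
except ValueError as err:
    # "`Trainer(strategy='colossalai', precision=32)` is not supported. ..."
    print(err)
```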
1 change: 1 addition & 0 deletions src/pytorch_lightning/strategies/__init__.py
@@ -13,6 +13,7 @@
# limitations under the License.
from lightning_lite.strategies.registry import _StrategyRegistry
from pytorch_lightning.strategies.bagua import BaguaStrategy # noqa: F401
from pytorch_lightning.strategies.colossalai import ColossalAIStrategy # noqa: F401
from pytorch_lightning.strategies.ddp import DDPStrategy # noqa: F401
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy # noqa: F401
from pytorch_lightning.strategies.deepspeed import DeepSpeedStrategy # noqa: F401
13 changes: 13 additions & 0 deletions src/pytorch_lightning/strategies/bagua.py
@@ -1,3 +1,16 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from typing import Any, Dict, List, Optional, Union