Implement _compute_intra_grad_corr_mean for gradient computation (#1095)
* Fix gradient accumulation

Add an ``is_scaled_loss`` flag to support both scaled and unscaled loss
Fix ``test_grad_accum`` and ``test_set_num_gradients_to_accumulate``

* Add a method to scale grad for grad_accum using unscaled loss

- Revert the changes in `step` method
- Add a method `scale_grad_by_num_grads_to_accum` to handle gradient accumulation using unscaled loss more explicitly (see the sketch after this change list)
- Add gradient tests

* Implement _compute_corr_mean_between_grads

* Improve tests and comments

* Use ubuntu-20.04 instead of latest

Use ubuntu-20.04 to fix the `arch x64 not found` issue
See actions/setup-python#401 ("Version 3.10 with arch x64 not found")

* Switch flake8 from GitLab to GitHub

Flake8 was moved to GitHub.
See the discussion at https://www.reddit.com/r/Python/comments/yvfww8/flake8_took_down_the_gitlab_repository_in_favor/

* Fix scikit-learn package

* Update PyTorch versions

* Resolve comments from Min

* Minor fix

* Disable broken tests for new versions of PyTorch
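
For context on the gradient-accumulation change referenced above, here is a minimal sketch of the unscaled-loss pattern in plain PyTorch (not the fairscale implementation; all names in the sketch are illustrative): instead of dividing each micro-batch loss by the number of accumulation steps, the accumulated gradients are divided once before the optimizer step, which is the rough idea behind `scale_grad_by_num_grads_to_accum`.

import torch
from torch import nn

# Sketch only: accumulate gradients from unscaled losses, then rescale once before stepping.
model = nn.Linear(3, 1, bias=False)
optim = torch.optim.SGD(model.parameters(), lr=0.1)
num_grads_to_accum = 4

for step, batch in enumerate(torch.randn(8, 2, 3)):
    loss = model(batch).sum()          # unscaled loss: no 1/num_grads_to_accum here
    loss.backward()                    # gradients accumulate in p.grad
    if (step + 1) % num_grads_to_accum == 0:
        for p in model.parameters():
            if p.grad is not None:
                p.grad.div_(num_grads_to_accum)  # scale the gradients once, instead of scaling each loss
        optim.step()
        optim.zero_grad()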
cyugao authored Dec 5, 2022
1 parent ee647b9 commit 99edeff
Showing 10 changed files with 137 additions and 11 deletions.
14 changes: 7 additions & 7 deletions .circleci/config.yml
@@ -100,16 +100,16 @@ install_dep_pytorch_lts: &install_dep_pytorch_lts
# most recent stable version
install_dep_pytorch_stable: &install_dep_pytorch_stable
- run:
name: Install Dependencies with torch 1.12.0
name: Install Dependencies with torch 1.13.0
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.11 && exit 0; fi
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.13 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.12.0+cu113 torchvision==0.13.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install --progress-bar off torch==1.13.0 torchvision==0.14.0 --extra-index-url https://download.pytorch.org/whl/cu113
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "12"], f"wrong torch version {torch.__version__}"'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "13"], f"wrong torch version {torch.__version__}"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
@@ -118,13 +118,13 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
name: Install Dependencies with a torch nightly preview build
command: |
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.12 && exit 0; fi
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.14 && exit 0; fi
# start installing
pip install --pre torch==1.13.0.dev20220825+cu113 torchvision==0.14.0.dev20220825+cu113 --extra-index-url https://download.pytorch.org/whl/nightly/cu113
pip install --pre torch==1.14.0.dev20221121+cu117 torchvision==0.15.0.dev20221121+cu117 --extra-index-url https://download.pytorch.org/whl/nightly/cu117
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "13"], f"wrong torch version {torch.__version__}"'
python -c 'import torch; assert torch.__version__.split(".")[:2] == ["1", "14"], f"wrong torch version {torch.__version__}"'
python -m torch.utils.collect_env
wget -O /home/circleci/venv/check_version.py https://raw.githubusercontent.com/min-xu-ai/check_verion/main/check_version.py
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
@@ -7,7 +7,7 @@ on:

jobs:
pre-commit:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
strategy:
matrix:
# make sure python versions are consistent with those used in .circleci/config.yml
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -27,7 +27,7 @@ repos:
hooks:
- id: black

- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
14 changes: 14 additions & 0 deletions fairscale/fair_dev/testing/golden_testing_data.py
@@ -47,3 +47,17 @@
"expected_bias_grad": [1.0, 1.0],
},
]

corr_mean_test_data = [
{
"inputs": [
[[1.0, 0.0, 2.0], [2.0, 0.0, 1.0]],
[[0.0, 1.0, 2.0], [2.0, 1.0, 0]],
[[3.0, 1.0, 2.0], [2.0, 1.0, -1.0]],
],
"expected_grad": [[1.5, 0.0, 1.5], [1.0, 1.0, 1.0], [2.5, 1.0, 0.5]],
# expected pearson correlation of two micro-batches
"expected_corr": [0.5, -1.0, 0.327327],
"expected_cos_similarity": [float("nan"), 0.8165, 0.8433],
}
]
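
For reference, the golden values above can be reproduced outside the test harness: with a 3-to-1 Linear layer without bias and a sum() loss, each rank's gradient is just its input row, so the expected correlation is the Pearson correlation between the two rows and the expected gradient is their mean. A quick sketch (assuming torch >= 1.10 for torch.corrcoef):

import torch

inputs = [
    [[1.0, 0.0, 2.0], [2.0, 0.0, 1.0]],
    [[0.0, 1.0, 2.0], [2.0, 1.0, 0.0]],
    [[3.0, 1.0, 2.0], [2.0, 1.0, -1.0]],
]
for rows in inputs:
    grads = torch.tensor(rows)            # per-rank gradients, one row per rank
    corr = torch.corrcoef(grads)[0, 1]    # intra-batch Pearson correlation
    mean_grad = grads.mean(dim=0)         # what all_reduce(SUM) / world_size produces
    print(round(corr.item(), 6), mean_grad.tolist())
# prints 0.5, -1.0, 0.327327 and the expected_grad rows listed above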
44 changes: 44 additions & 0 deletions fairscale/optim/adascale.py
@@ -387,6 +387,49 @@ def _update_avg(self, name: str, value: np.ndarray, factor: float) -> None:
else:
self._state[name] = factor * self._state[name] + (1.0 - factor) * value

def _gather_flat_grad(self) -> torch.Tensor:
"""
Helper function for gathering all gradients into a single vector.
Duplicated from torch.optim.lbfgs.
"""

def _to_flat_view(p: torch.Tensor) -> torch.Tensor:
"""
Local helper function for _gather_flat_grad.
Returns a flattened view of the input tensor.
"""
if p.grad is None:
return p.new(p.numel()).zero_() # type: ignore
elif p.grad.is_sparse: # type: ignore
return p.grad.to_dense().view(-1)
else:
return p.grad.view(-1)

views = [_to_flat_view(p) for param_group in self._optimizer.param_groups for p in param_group["params"]]
return torch.cat(views, 0)

def _compute_intra_grad_corr_mean(self) -> torch.Tensor:
"""
Helper function for computing average intra correlation among gradients on different GPUs.
This should be called under `model.no_sync()` context.
"""
assert self._world_size > 1, "Only for distributed training"
flat_grad = self._gather_flat_grad()
corr_mean = torch.tensor(0.0).cuda()
if dist.get_rank() == 0:
size = flat_grad.numel()
gathered_tensors = [torch.zeros(size, device=0) for _ in range(self._world_size)]
dist.gather(flat_grad, gather_list=gathered_tensors, dst=0)
# the following requires torch 1.10+
corr = torch.stack(gathered_tensors).corrcoef() # type: ignore
# pick out the upper triangular part of the correlation matrix
corr = corr[torch.triu(torch.ones_like(corr), diagonal=1) == 1]
corr_mean = corr.mean()
else:
dist.gather(flat_grad, gather_list=None, dst=0)
dist.broadcast(corr_mean, src=0)
return corr_mean

def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None:
# This method should be invoked once for each parameter during the
# backward pass, before gradients are synchronized between world_size.
@@ -449,6 +492,7 @@ def _final_callback(self) -> None:
return

# Since self._local_grad_sqr is FP32, sum shouldn't overflow.

# This vector has length of # of param_groups, so it is small, but we
# use async to hide the all_reduce latency, esp when # of nodes is large.
work = None
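
A minimal usage sketch for the new `_compute_intra_grad_corr_mean` helper added above (mirroring the test added later in this commit; `model`, `batch`, `optim`, and `world_size` are placeholders, with `model` a DDP-wrapped module and `optim` an AdaScale-wrapped optimizer):

import torch.distributed as dist

def corr_mean_step(model, batch, optim, world_size):
    # Sketch only: compute the intra-step gradient correlation while grads are still per-rank.
    with model.no_sync():                       # keep per-rank gradients; DDP does not reduce them here
        model(batch).sum().backward()
        corr_mean = optim._compute_intra_grad_corr_mean()
    # DDP did not synchronize gradients inside no_sync(), so reduce them manually before stepping.
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)
    optim.step()
    optim.zero_grad()
    return corr_mean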
3 changes: 1 addition & 2 deletions requirements-dev.txt
@@ -28,11 +28,10 @@ pynvml == 8.0.4

# For mypy typing. It is important to have a fixed version. Otherwise, you
# may run into mypy errors out differently for different versions.
# Using 1.21.5 for now because py3.7 only has up to 1.21.5, not 1.22.x.
numpy == 1.22.0

# For layerwise gradient scaler
sklearn >= 0.0
scikit-learn == 1.1.3

# For weigit. These are actually user requirements, not developer requirements.
# However, due to the experimental nature of weigit, we don't expose to the
4 changes: 4 additions & 0 deletions tests/nn/checkpoint/test_checkpoint_activations.py
@@ -83,6 +83,10 @@ def forward(self, x):


@pytest.mark.parametrize("device", ["cpu", "cuda"])
@pytest.mark.skipif(
torch_version() >= (1, 13, 0),
reason="mem_peak behavior changed for torch 1.13 and above",
)
def test_basic(device):
if "cuda" in device and not torch.cuda.is_available():
pytest.skip("test requires a GPU")
4 changes: 4 additions & 0 deletions tests/nn/data_parallel/test_fsdp_memory.py
@@ -162,6 +162,10 @@ def cmp(results, expected):
@pytest.mark.timeout(120)
@pytest.mark.parametrize("ckpt", ["no_ckpt", "ckpt"])
@pytest.mark.parametrize("fsdp", ["ddp", "fsdp", "fsdp_amp_default", "fsdp_amp_compute_dtype32"])
@pytest.mark.skipif(
torch_version() >= (1, 14, 0),
reason="Tests broke in Pytorch pre-release version 1.14",
)
def test_fsdp_memory(fsdp, ckpt):
expected = {
("ddp", "no_ckpt"): {
@@ -303,6 +303,10 @@ def _get_cached_results(
@pytest.mark.parametrize("wrap_bn", ["auto_wrap_bn", "no_auto_wrap_bn"])
@pytest.mark.parametrize("model_type", ["model1", "model2"])
@pytest.mark.parametrize("bn_type", ["bn", "sync_bn"])
@pytest.mark.skipif(
torch_version() >= (1, 14, 0),
reason="Tests broke in Pytorch pre-release version 1.14",
)
def test_multiple_forward_checkpoint(precision, flatten, wrap_bn, model_type, bn_type):
mixed_precision = precision == "mixed"
flatten = flatten == "flatten"
57 changes: 57 additions & 0 deletions tests/optim/test_ddp_adascale.py
@@ -35,6 +35,7 @@

from fairscale.fair_dev.testing.golden_testing_data import adascale_test_data
from fairscale.fair_dev.testing.testing import skip_if_single_gpu
from fairscale.internal import torch_version
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
from fairscale.nn.data_parallel import ShardedDataParallel as SDP
from fairscale.optim import OSS, AdaScale
@@ -152,3 +153,59 @@ def test_grad_accum():
temp_file_name = tempfile.mkstemp()[1]

mp.spawn(_test_grad_accum_func, args=(world_size, temp_file_name), nprocs=world_size, join=True)


def _test_corr_mean_func(rank, world_size, tempfile_name, test_case):
_dist_init(rank, world_size, tempfile_name, backend="gloo") # Covers gloo

model = Linear(3, 1, bias=False)
model.to("cuda")
model = DDP(model, device_ids=[rank])
optim = AdaScale(SGD(model.parameters(), lr=0.1))
results = []
last_grad = None
for i, in_data in enumerate(test_case["inputs"]):
# use no_sync so we can access nonreduced gradients
with model.no_sync():
in_data = Tensor(in_data[rank]).cuda()
out = model(in_data)
out.sum().backward()
results.append(optim._compute_intra_grad_corr_mean().item())
# sync gradients manually
for p in model.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
# divide by world size
p.grad.data.div_(world_size)
grad = optim._gather_flat_grad()
assert np.allclose(grad.cpu(), test_case["expected_grad"][i])
optim.step()
if last_grad is not None:
# compute cosine similarity
cos_similarity = torch.dot(grad, last_grad) / (grad.norm() * last_grad.norm())
np.allclose(cos_similarity.cpu(), test_case["expected_cos_similarity"][i])
last_grad = grad
optim.zero_grad()
assert np.allclose(results, test_case["expected_corr"]), results

dist.destroy_process_group()


@skip_if_single_gpu
@pytest.mark.skipif(
torch_version() < (1, 10, 0),
reason="torch.corrcoef available only for torch 1.10 or higher",
)
def test_corr_mean():
"""
Test _compute_intra_grad_corr_mean and _gather_flat_grad using ddp.no_sync()
We also demonstrate how cosine similarity between consecutive gradients can be computed using _gather_flat_grad
"""
world_size = 2
temp_file_name = tempfile.mkstemp()[1]

from fairscale.fair_dev.testing.golden_testing_data import corr_mean_test_data

test_case = corr_mean_test_data[0]

mp.spawn(_test_corr_mean_func, args=(world_size, temp_file_name, test_case), nprocs=world_size, join=True)
