Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement _compute_intra_grad_corr_mean for gradient computation #1095

Merged
merged 13 commits into from
Dec 5, 2022
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ install_dep_pytorch_stable: &install_dep_pytorch_stable
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.11 && exit 0; fi
# start installing
pip install --progress-bar off torch==1.12.0+cu113 torchvision==0.13.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install --progress-bar off torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
cyugao marked this conversation as resolved.
Show resolved Hide resolved
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
Expand All @@ -120,7 +120,7 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.12 && exit 0; fi
# start installing
pip install --pre torch==1.13.0.dev20220825+cu113 torchvision==0.14.0.dev20220825+cu113 --extra-index-url https://download.pytorch.org/whl/nightly/cu113
cyugao marked this conversation as resolved.
Show resolved Hide resolved
pip install --pre torch==1.13.0 torchvision==0.14.0 --extra-index-url https://download.pytorch.org/whl/nightly/cu113
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ on:

jobs:
pre-commit:
runs-on: ubuntu-latest
runs-on: ubuntu-20.04
strategy:
matrix:
# make sure python versions are consistent with those used in .circleci/config.yml
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ repos:
hooks:
- id: black

- repo: https://gitlab.com/pycqa/flake8
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
Expand Down
14 changes: 14 additions & 0 deletions fairscale/fair_dev/testing/golden_testing_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,17 @@
"expected_bias_grad": [1.0, 1.0],
},
]

corr_mean_test_data = [
{
"inputs": [
[[1.0, 0.0, 2.0], [2.0, 0.0, 1.0]],
[[0.0, 1.0, 2.0], [2.0, 1.0, 0]],
[[3.0, 1.0, 2.0], [2.0, 1.0, -1.0]],
],
"expected_grad": [[1.5, 0.0, 1.5], [1.0, 1.0, 1.0], [2.5, 1.0, 0.5]],
# expected pearson correlation of two micro-batches
"expected_corr": [0.5, -1.0, 0.327327],
"expected_cos_similarity": [float("nan"), 0.8165, 0.8433],
}
]
41 changes: 41 additions & 0 deletions fairscale/optim/adascale.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,44 @@ def _update_avg(self, name: str, value: np.ndarray, factor: float) -> None:
else:
self._state[name] = factor * self._state[name] + (1.0 - factor) * value

def _gather_flat_grad(self) -> torch.Tensor:
"""
Helper function for gathering all gradients into a single vector.
Duplicated from torch.optim.lbfgs.
"""
views = []
for param_group in self._optimizer.param_groups:
cyugao marked this conversation as resolved.
Show resolved Hide resolved
for p in param_group["params"]:
if p.grad is None:
view = p.new(p.numel()).zero_()
elif p.grad.is_sparse:
view = p.grad.to_dense().view(-1)
else:
view = p.grad.view(-1)
views.append(view)
return torch.cat(views, 0)

def _compute_intra_grad_corr_mean(self) -> float:
"""
Helper function for computing average intra correlation among gradients on different GPUs.
cyugao marked this conversation as resolved.
Show resolved Hide resolved
"""
assert self._world_size > 1, "Only for distributed training"
flat_grad = self._gather_flat_grad()
corr_mean = torch.tensor(0.0).cuda()
if dist.get_rank() == 0:
size = flat_grad.numel()
gathered_tensors = [torch.zeros(size, device=0) for _ in range(self._world_size)]
dist.gather(flat_grad, gather_list=gathered_tensors, dst=0)
# the following requires torch 1.10+
corr = torch.stack(gathered_tensors).corrcoef() # type: ignore
# pick out the upper triangular part of the correlation matrix
corr = corr[torch.triu(torch.ones_like(corr), diagonal=1) == 1]
corr_mean = corr.mean()
else:
dist.gather(flat_grad, gather_list=None, dst=0)
dist.broadcast(corr_mean, src=0)
return corr_mean.item()
cyugao marked this conversation as resolved.
Show resolved Hide resolved

def _backward_hook(self, pg_idx: int, grad: torch.Tensor) -> None:
# This method should be invoked once for each parameter during the
# backward pass, before gradients are synchronized between world_size.
Expand Down Expand Up @@ -449,6 +487,9 @@ def _final_callback(self) -> None:
return

# Since self._local_grad_sqr is FP32, sum shouldn't overflow.

# TODO: Hongbo says param.grad might be FP16 should do this before converting to FP32.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is a bit too sparse. Maybe you can expand this and provide more context?


# This vector has length of # of param_groups, so it is small, but we
# use async to hide the all_reduce latency, esp when # of nodes is large.
work = None
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pynvml == 8.0.4
numpy == 1.22.0

# For layerwise gradient scaler
sklearn >= 0.0
scikit-learn >= 0.0
cyugao marked this conversation as resolved.
Show resolved Hide resolved

# For weigit. These are actually user requirements, not developer requirements.
# However, due to the experimental nature of weigit, we don't expose to the
Expand Down
52 changes: 52 additions & 0 deletions tests/optim/test_ddp_adascale.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,3 +152,55 @@ def test_grad_accum():
temp_file_name = tempfile.mkstemp()[1]

mp.spawn(_test_grad_accum_func, args=(world_size, temp_file_name), nprocs=world_size, join=True)


def _test_corr_mean_func(rank, world_size, tempfile_name, test_case):
_dist_init(rank, world_size, tempfile_name, backend="gloo") # Covers gloo

model = Linear(3, 1, bias=False)
model.to("cuda")
model = DDP(model, device_ids=[rank])
optim = AdaScale(SGD(model.parameters(), lr=0.1))
results = []
last_grad = None
for i, in_data in enumerate(test_case["inputs"]):
# use no_sync so we can access nonreduced gradients
with model.no_sync():
in_data = Tensor(in_data[rank]).cuda()
out = model(in_data)
out.sum().backward()
results.append(optim._compute_intra_grad_corr_mean())
# sync gradients manually
for p in model.parameters():
if p.grad is not None:
dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
# divide by world size
p.grad.data.div_(world_size)
grad = optim._gather_flat_grad()
assert np.allclose(grad.cpu(), test_case["expected_grad"][i])
optim.step()
if last_grad is not None:
# compute cosine similarity
cos_similarity = torch.dot(grad, last_grad) / (grad.norm() * last_grad.norm())
np.allclose(cos_similarity.cpu(), test_case["expected_cos_similarity"][i])
last_grad = grad
optim.zero_grad()
assert np.allclose(results, test_case["expected_corr"]), results

dist.destroy_process_group()


@skip_if_single_gpu
def test_corr_mean():
"""
Test _compute_intra_grad_corr_mean and _gather_flat_grad using ddp.no_sync()
We also demonstrate how cosine similarity between consecutive gradients can be computed using _gather_flat_grad
"""
world_size = 2
temp_file_name = tempfile.mkstemp()[1]

from fairscale.fair_dev.testing.golden_testing_data import corr_mean_test_data

test_case = corr_mean_test_data[0]

mp.spawn(_test_corr_mean_func, args=(world_size, temp_file_name, test_case), nprocs=world_size, join=True)