Report of increased memory overhead during cudagraph capture with nccl >= 2.19 #1234

Closed
youkaichao opened this issue Mar 24, 2024 · 15 comments

Comments

@youkaichao

Hi, I would like to report a memory issue with nccl. A reproducible example is attached below:

On a GCP g2-standard-24 instance (with 2 L4 GPUs):

docker pull us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:a3c2340ae36ce8ee782691d30111377eaf7ae6ce
docker run --gpus all --shm-size=2g -it us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:a3c2340ae36ce8ee782691d30111377eaf7ae6ce -- /bin/bash

# inside docker
cd /vllm-workspace/tests/distributed
export NCCL_DEBUG=TRACE
TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf pytest -v -s --forked test_basic_distributed_correctness.py

Note that the code manually links against a pre-downloaded nccl==2.18.3. The image also contains nccl==2.19.3, located at /usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib/libnccl.so.2.

By adding the following line in /vllm-workspace/vllm/model_executor/parallel_utils/pynccl.py, before nccl = ctypes.CDLL(so_file):

so_file = "/usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib/libnccl.so.2"

we can force the program to use nccl 2.19.3, and the test then fails with an OOM error.
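To double-check which library actually got loaded after overriding so_file, one can ask the library for its version. Below is a minimal sketch using ncclGetVersion through ctypes; the helper names are hypothetical, and the two encodings follow the NCCL_VERSION macro in nccl.h (the older one applies only to releases before 2.9).

```python
import ctypes


def decode_nccl_version(code: int) -> str:
    """Turn the integer reported by ncclGetVersion() into "major.minor.patch".

    NCCL 2.9+ encodes versions as major*10000 + minor*100 + patch;
    older releases used major*1000 + minor*100 + patch.
    """
    if code >= 10000:
        major, rest = divmod(code, 10000)
    else:
        major, rest = divmod(code, 1000)
    minor, patch = divmod(rest, 100)
    return f"{major}.{minor}.{patch}"


def loaded_nccl_version(so_file: str) -> str:
    """Load an NCCL shared library and ask it for its version (needs the .so present)."""
    nccl = ctypes.CDLL(so_file)
    version = ctypes.c_int()
    result = nccl.ncclGetVersion(ctypes.byref(version))
    if result != 0:  # 0 == ncclSuccess
        raise RuntimeError(f"ncclGetVersion failed with code {result}")
    return decode_nccl_version(version.value)


print(decode_nccl_version(21803))  # 2.18.3
print(decode_nccl_version(21903))  # 2.19.3
```

On the machine above, calling loaded_nccl_version with the 2.19.3 path should report "2.19.3" if the override took effect.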

The background:

In distributed inference, vLLM (https://github.com/vllm-project/vllm) uses nccl together with CUDA graphs. We capture about 30 graphs with different batch sizes. With pytorch 2.1.2 (with nccl==2.18.3), the memory overhead is nearly zero (about 10MB per graph, and sometimes zero); however, after upgrading to pytorch 2.2.0 (with nccl==2.19.3), the memory overhead is more than 100MB per graph.

We spent more than a week (to be honest, more time than we would have liked) investigating the issue. We initially thought it was related to pytorch, but finally found that the problem comes from the nccl library.

For more code on measuring the memory overhead, please check vllm-project/vllm#3442 (comment).

It would be very helpful if the nccl team could point out the root cause of the memory overhead and any potential knobs to control it (e.g. via some environment variables). The above problem happens with both nccl==2.19.3 and nccl==2.20.5.

Thank you for your time.

@sjeaugey
Member

Thanks for the report. @jbachan did we change something in 2.19 regarding the way we store graph-related data for NCCL?

@youkaichao how many graphs do you expect to create? An order of magnitude is ok, it doesn't have to be precise; in my mind apps would create a couple of graphs (fewer than 10 for sure), so even 100MB per graph doesn't seem like a big problem. If the expectation is that we should allow for tens or hundreds of graphs, it could influence how we fix the issue and design things in the future.

@youkaichao
Author

how many graphs do you expect to create?

About 30, each for a different batch size (8, 16, 24, ..., 256).

Specifically, we capture the graphs with large batch sizes first, followed by the smaller ones. We find that CUDA graphs with smaller batch sizes can share the memory buffers of those with larger batch sizes. So even with more than 30 CUDA graphs, the memory overhead before nccl 2.19 is low.
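Why the capture order matters can be illustrated with a deliberately simplified toy model. Everything here is an assumption: each graph is taken to need a single buffer proportional to its batch size, a capture reuses an existing block from the shared pool if one is large enough, and the batch sizes are assumed to step by 8 as the list above suggests. Real allocators split and cache blocks in more elaborate ways.

```python
def reserved_memory(capture_order, need):
    """Toy model: total MB reserved when all graphs share one memory pool.

    A capture reuses an existing block if one is large enough; otherwise it
    reserves a new block of exactly the size it needs.
    """
    blocks = []
    for batch in capture_order:
        if not any(size >= need[batch] for size in blocks):
            blocks.append(need[batch])
    return sum(blocks)


# Hypothetical sizes: batch sizes 8, 16, ..., 256, each needing `batch` MB.
batch_sizes = list(range(8, 257, 8))
need = {b: b for b in batch_sizes}

print(len(batch_sizes))                                           # 32
print(reserved_memory(sorted(batch_sizes, reverse=True), need))   # 256
print(reserved_memory(sorted(batch_sizes), need))                 # 4224
```

Capturing largest-first lets every later graph fit in the first block, so the pool stays at the size of the largest graph; smallest-first forces a new, larger block at every step.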

@youkaichao
Author

@sjeaugey hi, any update on this? 👀

@youkaichao
Author

@sjeaugey hi, any update on this? 👀

@youkaichao
Author

@sjeaugey we finally found that this problem is related to virtual memory usage, and solved it in vllm-project/vllm#5091.

It would be great if you could share why virtual memory is allocated during graph capture, what it is used for, and what the performance impact of turning it off is 🙏

@sjeaugey
Member

Thanks for the feedback. I'm not sure I understand how you solved it though. Was it a proper fix or a workaround?

From what I understand, you're saying that enabling CUMEM increases the memory usage when using CUDA graphs. Is that accurate? Did you confirm the problem went away with NCCL_CUMEM_ENABLE=0?

@youkaichao
Author

Yes, the fix is to set NCCL_CUMEM_ENABLE=0.
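From Python, one way to apply this workaround is to set the variable in the process environment before NCCL is initialized. The sketch below is hypothetical (the helper name is illustrative, not vLLM's code) and assumes NCCL picks the value up when the first communicator is created, so it should run before any communicator exists; child processes inherit it.

```python
import os


def disable_nccl_cumem() -> None:
    """Ask NCCL not to use cuMem-based allocations in this process and its children.

    Call this before the first NCCL communicator is created; changing the
    variable afterwards is unlikely to affect an existing communicator.
    An explicit user setting is left untouched.
    """
    os.environ.setdefault("NCCL_CUMEM_ENABLE", "0")


disable_nccl_cumem()
print(os.environ["NCCL_CUMEM_ENABLE"])
```

Using setdefault rather than plain assignment is a design choice: it keeps the workaround as a default while still letting a user opt back in with NCCL_CUMEM_ENABLE=1.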

@youkaichao
Author

I would say this is only a workaround. I don't know why NCCL uses more memory with CUDA graphs when NCCL_CUMEM_ENABLE is on, as it is by default in nccl 2.19.

@youkaichao
Author

Here is a minimal reproducible example:

import torch
import torch.distributed as dist
from contextlib import contextmanager

@contextmanager
def graph_capture(pool=None, stream=None, capture_error_mode: str = "global", dump_path=None):
    g = torch.cuda.CUDAGraph()
    if dump_path is not None:
        g.enable_debug_mode()
    with torch.cuda.graph(cuda_graph=g, pool=pool, stream=stream, capture_error_mode=capture_error_mode):
        yield g
    if dump_path is not None:
        g.debug_dump(dump_path)

# Launch with one process per GPU, e.g. torchrun --nproc_per_node=2 <this script>
dist.init_process_group(backend="gloo", init_method="env://")
rank = dist.get_rank()
torch.cuda.set_device(rank)

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

pynccl = PyNcclCommunicator(group=dist.group.WORLD, device=rank)
pynccl.disabled = False

MAX_BATCHSIZE = 4

# Placeholder input used for capture
static_a = torch.zeros((MAX_BATCHSIZE, 1024), device="cuda")

def compute(batchsize):
    pynccl.all_reduce(static_a[:batchsize], stream=torch.cuda.current_stream())

# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(1, MAX_BATCHSIZE + 1):
        compute(i)
torch.cuda.current_stream().wait_stream(s)

def report_memory(prefix):
    free, total = torch.cuda.mem_get_info()
    used = total - free
    print(f"{prefix}: Used: {used / 1024 / 1024} MB, Free: {free / 1024 / 1024} MB, Total: {total / 1024 / 1024} MB")

# Captures the graph
# To allow capture, automatically sets a side stream as the current stream in the context
report_memory("Before capture")
graphs = [0] # 0 is a placeholder for 0 batchsize
memory_pool = None
for i in range(1, MAX_BATCHSIZE + 1):
    with graph_capture(pool=memory_pool) as g:
        compute(i)
    graphs.append(g)
    memory_pool = g.pool()
    report_memory(f"After capture batchsize {i}")
# Run the graph
static_a[:2] += 1
graphs[2].replay()
torch.cuda.current_stream().synchronize()
print(static_a[:2])

I call allreduce on a static buffer and capture only this allreduce operation in the CUDA graph. Ideally, it should not cost any extra memory.

When I run it with the default settings:

INFO 05-30 18:33:53 utils.py:619] Found nccl from library libnccl.so.2
INFO 05-30 18:33:53 pynccl.py:65] vLLM is using nccl==2.20.5
INFO 05-30 18:33:53 utils.py:619] Found nccl from library libnccl.so.2
INFO 05-30 18:33:53 pynccl.py:65] vLLM is using nccl==2.20.5
Before capture: Used: 1373.375 MB, Free: 79677.25 MB, Total: 81050.625 MB
Before capture: Used: 1373.375 MB, Free: 79677.25 MB, Total: 81050.625 MB
After capture batchsize 1: Used: 1379.375 MB, Free: 79671.25 MB, Total: 81050.625 MB
After capture batchsize 1: Used: 1379.375 MB, Free: 79671.25 MB, Total: 81050.625 MB
After capture batchsize 2: Used: 1381.375 MB, Free: 79669.25 MB, Total: 81050.625 MB
After capture batchsize 2: Used: 1381.375 MB, Free: 79669.25 MB, Total: 81050.625 MB
After capture batchsize 3: Used: 1383.375 MB, Free: 79667.25 MB, Total: 81050.625 MB
After capture batchsize 3: Used: 1383.375 MB, Free: 79667.25 MB, Total: 81050.625 MB
After capture batchsize 4: Used: 1385.375 MB, Free: 79665.25 MB, Total: 81050.625 MB
After capture batchsize 4: Used: 1385.375 MB, Free: 79665.25 MB, Total: 81050.625 MB
tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.]], device='cuda:0')
tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.]], device='cuda:1')

Each additional CUDA graph takes 2 MB of memory.

If I run it with export NCCL_CUMEM_ENABLE=0:

INFO 05-30 18:30:52 utils.py:619] Found nccl from library libnccl.so.2
INFO 05-30 18:30:52 pynccl.py:65] vLLM is using nccl==2.20.5
INFO 05-30 18:30:52 utils.py:619] Found nccl from library libnccl.so.2
INFO 05-30 18:30:52 pynccl.py:65] vLLM is using nccl==2.20.5
Before capture: Used: 1181.375 MB, Free: 79869.25 MB, Total: 81050.625 MB
Before capture: Used: 1181.375 MB, Free: 79869.25 MB, Total: 81050.625 MB
After capture batchsize 1: Used: 1185.375 MB, Free: 79865.25 MB, Total: 81050.625 MB
After capture batchsize 1: Used: 1185.375 MB, Free: 79865.25 MB, Total: 81050.625 MB
After capture batchsize 2: Used: 1185.375 MB, Free: 79865.25 MB, Total: 81050.625 MB
After capture batchsize 2: Used: 1185.375 MB, Free: 79865.25 MB, Total: 81050.625 MB
After capture batchsize 3: Used: 1185.375 MB, Free: 79865.25 MB, Total: 81050.625 MB
After capture batchsize 3: Used: 1185.375 MB, Free: 79865.25 MB, Total: 81050.625 MB
After capture batchsize 4: Used: 1185.375 MB, Free: 79865.25 MB, Total: 81050.625 MB
After capture batchsize 4: Used: 1185.375 MB, Free: 79865.25 MB, Total: 81050.625 MB
tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.]], device='cuda:1')
tensor([[2., 2., 2.,  ..., 2., 2., 2.],
        [2., 2., 2.,  ..., 2., 2., 2.]], device='cuda:0')

The memory usage does not increase when I capture more graphs.
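To make the comparison between the two runs explicit, one can diff consecutive "Used" readings. This is a small sketch; the helper name is hypothetical, and the sample strings are one rank's lines copied from the two runs above, trimmed to the Used field.

```python
import re


def per_capture_deltas(log: str) -> list[float]:
    """Increase in used memory (MB) at each step, from one rank's log lines."""
    used = [float(m.group(1)) for m in re.finditer(r"Used: ([\d.]+) MB", log)]
    return [after - before for before, after in zip(used, used[1:])]


default_run = """
Before capture: Used: 1373.375 MB
After capture batchsize 1: Used: 1379.375 MB
After capture batchsize 2: Used: 1381.375 MB
After capture batchsize 3: Used: 1383.375 MB
After capture batchsize 4: Used: 1385.375 MB
"""

cumem_off_run = """
Before capture: Used: 1181.375 MB
After capture batchsize 1: Used: 1185.375 MB
After capture batchsize 2: Used: 1185.375 MB
After capture batchsize 3: Used: 1185.375 MB
After capture batchsize 4: Used: 1185.375 MB
"""

print(per_capture_deltas(default_run))    # [6.0, 2.0, 2.0, 2.0]
print(per_capture_deltas(cumem_off_run))  # [4.0, 0.0, 0.0, 0.0]
```

In the default run the first capture costs 6 MB and each later one 2 MB; with NCCL_CUMEM_ENABLE=0 only the first capture adds memory (4 MB), and later captures add nothing.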

@youkaichao
Author

@sjeaugey I just tried nccl 2.21.5, and the problem still exists. I wonder whether this is because cuMemCreate is captured in the CUDA graph. I don't see any documentation explaining the behavior of cuMemCreate under graph capture and graph execution.

@sjeaugey
Member

sjeaugey commented Jun 3, 2024

For each CUDA graph capture, we allocate some memory on the GPU to store the information related to that persistent operation:
https://github.com/NVIDIA/nccl/blob/master/src/enqueue.cc#L1094

With CUMEM enabled, each allocation has to be aligned to the mem granularity, i.e. 2MB, so it is not surprising you see 2MB allocated per graph.
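The effect of the granularity on a small allocation is plain round-up arithmetic. A minimal sketch, assuming the 2MB granularity quoted above (a real program would query it with cuMemGetAllocationGranularity) and a hypothetical 4 KB per-operation buffer:

```python
def align_up(size: int, granularity: int) -> int:
    """Round `size` up to the next multiple of `granularity`."""
    return (size + granularity - 1) // granularity * granularity


MiB = 1024 * 1024
GRANULARITY = 2 * MiB  # value quoted above; assumed, not queried from the driver

# A hypothetical few-KB buffer still reserves a full 2 MiB granule.
print(align_up(4096, GRANULARITY) // MiB)     # 2
print(align_up(3 * MiB, GRANULARITY) // MiB)  # 4
```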

I guess we'd need to add a sub-allocator to the CUMEM code to avoid allocating 2MB for each CUDA graph. CC @AddyLaddy
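The sub-allocator idea can be sketched as a toy bump allocator that carves many small buffers out of one granule-sized chunk. This is a hypothetical illustration of the concept, not NCCL's design: it only tracks offsets, assumes a 256-byte alignment and 4 KB requests, and never frees anything (a real one would have to release chunks when the graphs using them are destroyed).

```python
class ChunkSubAllocator:
    """Toy bump allocator: carve small buffers out of large granule-sized chunks.

    Only bookkeeping is modeled; in a real implementation each new chunk would
    correspond to one physical allocation of `chunk_size` bytes.
    """

    def __init__(self, chunk_size: int, alignment: int = 256):
        self.chunk_size = chunk_size
        self.alignment = alignment
        self.chunks = 0  # number of chunks allocated so far
        self.offset = 0  # next free byte in the current chunk

    def alloc(self, size: int) -> tuple[int, int]:
        """Return (chunk index, byte offset) for a buffer of `size` bytes."""
        if size > self.chunk_size:
            raise ValueError("request larger than a chunk")
        start = -(-self.offset // self.alignment) * self.alignment  # align up
        if self.chunks == 0 or start + size > self.chunk_size:
            self.chunks += 1  # open a fresh chunk
            start = 0
        self.offset = start + size
        return self.chunks - 1, start


MiB = 1024 * 1024
alloc = ChunkSubAllocator(chunk_size=2 * MiB)
for _ in range(64):  # e.g. 64 hypothetical 4 KB per-operation buffers
    alloc.alloc(4096)
print(alloc.chunks)  # 1
```

Here 64 small buffers share a single 2 MiB chunk, whereas one granule per buffer would reserve 128 MiB.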

@youkaichao
Copy link
Author

@sjeaugey so it is not 2MB per cuda graph, it is 2MB per allreduce operation per cuda graph.

In total, it costs 2MB * (# allreduce operations per graph) * (# graphs), which adds up to gigabytes in our case.
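The formula is easy to put into numbers. In the sketch below the helper name is hypothetical, and the model-scale figures are assumptions for illustration only: 64 allreduces per forward pass (two per layer for a 32-layer model such as Llama-2-7b under tensor parallelism) and 32 captured graphs.

```python
MiB = 1024 * 1024


def cumem_graph_overhead(granule_bytes: int, ops_per_graph: int, num_graphs: int) -> int:
    """Extra bytes if every NCCL op in every captured graph gets its own granule."""
    return granule_bytes * ops_per_graph * num_graphs


# Two allreduces in one graph: 4 MiB per graph.
print(cumem_graph_overhead(2 * MiB, 2, 1) // MiB)    # 4
# Hypothetical model-scale numbers: 64 allreduces per forward pass, 32 graphs.
print(cumem_graph_overhead(2 * MiB, 64, 32) // MiB)  # 4096
```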

In my minimal reproducible example, when I add multiple allreduce operations, e.g.

def compute(batchsize):
    pynccl.all_reduce(static_a[:batchsize], stream=torch.cuda.current_stream())
    pynccl.all_reduce(static_a[:batchsize], stream=torch.cuda.current_stream())

I do see the memory overhead grow proportionally (with 2 allreduce operations, each graph takes 4MB more memory):

Before capture: Used: 491.6875 MB, Free: 32002.4375 MB, Total: 32494.125 MB
Before capture: Used: 491.6875 MB, Free: 32002.4375 MB, Total: 32494.125 MB
After capture batchsize 1: Used: 499.6875 MB, Free: 31994.4375 MB, Total: 32494.125 MB
After capture batchsize 1: Used: 499.6875 MB, Free: 31994.4375 MB, Total: 32494.125 MB
After capture batchsize 2: Used: 503.6875 MB, Free: 31990.4375 MB, Total: 32494.125 MB
After capture batchsize 2: Used: 503.6875 MB, Free: 31990.4375 MB, Total: 32494.125 MB
After capture batchsize 3: Used: 507.6875 MB, Free: 31986.4375 MB, Total: 32494.125 MB
After capture batchsize 3: Used: 507.6875 MB, Free: 31986.4375 MB, Total: 32494.125 MB
After capture batchsize 4: Used: 511.6875 MB, Free: 31982.4375 MB, Total: 32494.125 MB
After capture batchsize 4: Used: 511.6875 MB, Free: 31982.4375 MB, Total: 32494.125 MB
tensor([[4., 4., 4.,  ..., 4., 4., 4.],
        [4., 4., 4.,  ..., 4., 4., 4.]], device='cuda:0')
tensor([[4., 4., 4.,  ..., 4., 4., 4.],
        [4., 4., 4.,  ..., 4., 4., 4.]], device='cuda:1')

Hope it can be fixed soon 🙏

@sjeaugey
Member

sjeaugey commented Jun 4, 2024

You're right, it's 2MB per operation within the graph. My previous comment still applies though; we'd need to implement a sub-allocator for CUMEM operations to reduce the memory usage with CUMEM.

@youkaichao
Author

How do the cuMem-related APIs behave under graph capture? The documentation does not say anything about it.

@davidthomas426

Is this right?

  • When using cumem ops, graph capture ignores it -- instead, by using the lower level cumem ops, you are directly allocating each time you run the cuda graph capture. I'm not sure how cuMemRelease is handled with graph capture though...
  • When not using cumem ops, cudaMallocAsync and cudaFreeAsync are basically used instead. These use a stream-ordered suballocator with special behavior during graph capture, so that buffers can get reused for non-overlapping operations and are eligible to become part of the shared memory pool across cuda graphs.

@sjeaugey Could you clarify the behavior of cumem ops with cuda graphs, especially what happens when pointers are captured that are then released with cuMemRelease? Does this become a use-after-free? If not, how?
