Report of increased memory overhead during cudagraph capture with nccl >= 2.19 #1234
Comments
Thanks for the report. @jbachan did we change something in 2.19 regarding the way we store graph-related data for NCCL? @youkaichao how many graphs do you expect to create? A power of ten is fine, it doesn't have to be precise; in my mind apps would create a couple of graphs (fewer than 10 for sure), so even 100MB per graph doesn't seem like a big problem. If the expectation is that we should allow for tens or hundreds of graphs, it could influence how we fix the issue and design things in the future.
About 30, each for a different batch size (8, 16, 24, ..., 256). Technically, we capture the graph with the largest batch size first, followed by the smaller batch sizes. We find that cudagraphs with smaller batch sizes can share the memory buffer of the larger ones, so even with more than 30 cudagraphs, the memory overhead prior to nccl 2.19 is low.
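As a rough sketch of this pool-sharing pattern (not vLLM's actual code; the toy forward pass and sizes are made up for illustration):

```python
import torch

# Toy stand-in for the model forward pass on a fixed input buffer.
static_input = torch.zeros((256, 1024), device="cuda")

def forward(batchsize):
    return static_input[:batchsize] * 2

# Warm up once outside of capture (recommended before graph capture).
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    forward(256)
torch.cuda.current_stream().wait_stream(s)

graphs = {}
pool = None
# Capture the largest batch size first; later captures pass pool=... so they
# draw from the same memory pool instead of allocating fresh memory.
for bs in range(256, 0, -8):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        out = forward(bs)
    if pool is None:
        pool = g.pool()
    graphs[bs] = g
```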
@sjeaugey hi, any update on this? 👀
@sjeaugey we finally found that this problem is related to virtual memory usage, and solved it in vllm-project/vllm#5091. It would be great if you could share why virtual memory is allocated during graph capture, what it is used for, and what the performance impact of turning it off is 🙏
Thanks for the feedback. I'm not sure I understand how you solved it though. Was it a proper fix or a workaround? From what I understand, you're saying that enabling CUMEM increases the memory usage when using CUDA graphs. Is that accurate? Did you confirm the problem went away with `NCCL_CUMEM_ENABLE=0`?
Yes, the fix is to set `NCCL_CUMEM_ENABLE=0`.
I would say this is only a workaround. I don't know why NCCL costs more memory with cudagraph when cumem is enabled.
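For completeness, a minimal sketch of applying this workaround (assuming the variable is set before any NCCL communicator is created, e.g. at the very top of the launch script):

```python
import os

# Disable NCCL's cuMem-based allocations. NCCL reads this at init time,
# so it must be set before the process group / communicators are created.
os.environ["NCCL_CUMEM_ENABLE"] = "0"
```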
Here is a minimal reproducible example:

```python
import torch
import torch.distributed as dist
from contextlib import contextmanager


@contextmanager
def graph_capture(pool=None, stream=None, capture_error_mode: str = "global", dump_path=None):
    g = torch.cuda.CUDAGraph()
    if dump_path is not None:
        g.enable_debug_mode()
    with torch.cuda.graph(cuda_graph=g, pool=pool, stream=stream, capture_error_mode=capture_error_mode):
        yield g
    if dump_path is not None:
        g.debug_dump(dump_path)


dist.init_process_group(backend="gloo", init_method="env://")
rank = dist.get_rank()
torch.cuda.set_device(rank)

from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator

pynccl = PyNcclCommunicator(group=dist.group.WORLD, device=rank)
pynccl.disabled = False

MAX_BATCHSIZE = 4

# Placeholder input used for capture
static_a = torch.zeros((MAX_BATCHSIZE, 1024), device="cuda")


def compute(batchsize):
    pynccl.all_reduce(static_a[:batchsize], stream=torch.cuda.current_stream())


# Warmup before capture
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for i in range(1, MAX_BATCHSIZE + 1):
        compute(i)
torch.cuda.current_stream().wait_stream(s)


def report_memory(prefix):
    free, total = torch.cuda.mem_get_info()
    used = total - free
    print(f"{prefix}: Used: {used / 1024 / 1024} MB, Free: {free / 1024 / 1024} MB, Total: {total / 1024 / 1024} MB")


# Capture the graphs.
# torch.cuda.graph automatically uses a side stream as the current stream inside the context.
report_memory("Before capture")
graphs = [0]  # index 0 is a placeholder so that graphs[i] is the graph for batch size i
memory_pool = None
for i in range(1, MAX_BATCHSIZE + 1):
    with graph_capture(pool=memory_pool) as g:
        compute(i)
    graphs.append(g)
    memory_pool = g.pool()
    report_memory(f"After capture batchsize {i}")

# Run the graph for batch size 2
static_a[:2] += 1
graphs[2].replay()
torch.cuda.current_stream().synchronize()
print(static_a[:2])
```

I call allreduce on a static buffer and capture only this allreduce operation in the cudagraph. Ideally it should not cost any memory. When I run it with the default settings,
every cudagraph takes 2 MB of memory. If I run it with `NCCL_CUMEM_ENABLE=0`, the memory cost does not increase when I capture more graphs.
@sjeaugey I just tried nccl 2.21.5 now, and the problem still exists. I suspect this is still related to cumem.
For each CUDA graph capture, we allocate some memory on the GPU to store the information related to that persistent operation. With CUMEM enabled, each allocation has to be aligned to the memory granularity, i.e. 2MB, so it is not surprising that you see 2MB allocated per graph. I guess we'd need to add a sub-allocator to the CUMEM code to avoid allocating a full 2MB for each CUDA graph. CC @AddyLaddy
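To illustrate the effect of that granularity (a rough sketch; the request sizes below are made up, and 2 MiB is the typical cuMem allocation granularity mentioned above):

```python
MIB = 1024 * 1024
GRANULARITY = 2 * MIB  # typical cuMem allocation granularity

def aligned_size(requested_bytes: int) -> int:
    # Every cuMem allocation is rounded up to a multiple of the granularity,
    # so even a small per-graph bookkeeping buffer consumes a full 2 MiB.
    return -(-requested_bytes // GRANULARITY) * GRANULARITY

print(aligned_size(4 * 1024) // MIB)  # 2 -> a 4 KiB request still costs 2 MiB
print(aligned_size(3 * MIB) // MIB)   # 4 -> 3 MiB rounds up to 4 MiB
```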
@sjeaugey so it is not 2MB per cuda graph, it is 2MB per allreduce operation per cuda graph. In total, it will cost 2MB × (allreduce operations per graph) × (number of graphs). In my minimal reproducible example, when I add multiple allreduce operations, e.g.

```python
def compute(batchsize):
    pynccl.all_reduce(static_a[:batchsize], stream=torch.cuda.current_stream())
    pynccl.all_reduce(static_a[:batchsize], stream=torch.cuda.current_stream())
```

I do see the memory overhead grow proportionally (for 2 allreduce operations, each graph takes 4MB more memory). Hope it can be fixed soon 🙏
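To make the scaling concrete, a back-of-the-envelope estimate using the ~30 graphs mentioned earlier in the thread (the number of allreduces per graph below is an assumption; it depends on the model):

```python
MIB = 1024 * 1024
per_op_overhead = 2 * MIB   # 2 MiB per captured NCCL operation, per graph
num_graphs = 30             # vLLM captures ~30 graphs, one per batch size
allreduces_per_graph = 50   # assumed value; depends on the number of layers

total = per_op_overhead * allreduces_per_graph * num_graphs
print(f"{total / MIB:.0f} MiB")  # 3000 MiB for these assumed numbers
```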
You're right, it's 2MB per operation within the graph. My previous comment still applies though; we'd need to implement a sub-allocator for CUMEM operations to reduce the memory usage with CUMEM.
How do the cumem-related APIs behave under graph capture? The documentation does not say anything about it.
Is this right?
@sjeaugey Could you clarify the behavior of cumem ops with cuda graphs, especially what happens when pointers are captured that are then released with cuMemRelease? Does this become a use-after-free? If not, how?
Hi, I would like to report a memory issue with `nccl`. A reproducible example is attached below, run in a GCP g2-standard-24 instance (with 2 L4 GPUs).

Note that the code manually links against a pre-downloaded `nccl==2.18.3`. There is also a `nccl==2.19.3` available inside the image; its path is `/usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib/libnccl.so.2`. By adding the following line in `/vllm-workspace/vllm/model_executor/parallel_utils/pynccl.py`, before `nccl = ctypes.CDLL(so_file)`, we can force the program to use nccl 2.19.3, and we then get an OOM error.
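The exact line is not preserved in this copy of the issue; a plausible, purely illustrative reconstruction is to override the library path that `ctypes.CDLL` loads:

```python
# Illustrative only: point pynccl at the nccl 2.19.3 library bundled in the
# image instead of the pre-downloaded 2.18.3 build.
so_file = "/usr/local/lib/python3.10/dist-packages/nvidia/nccl/lib/libnccl.so.2"
```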
The background:

In distributed inference, https://github.com/vllm-project/vllm uses nccl together with cudagraph. We capture about 30 graphs with different batch sizes. The memory overhead when we use `pytorch 2.1.2` (with `nccl==2.18.3`) is nearly zero (about 10MB per graph, and sometimes zero); however, when we upgrade to `pytorch 2.2.0` (with `nccl==2.19.3`), the memory overhead is more than 100MB per graph.

We spent more than a week (to be honest, more time than one would feel comfortable with) investigating the issue. We used to think it was related to pytorch, but we finally found that the problem comes from the `nccl` library.

For more code on measuring the memory overhead, please check vllm-project/vllm#3442 (comment).

It would be very helpful if the `nccl` team could point out the root cause of the memory overhead, and potential knobs to control it (e.g. via some environment variables). The above problem happens for both `nccl==2.19.3` and `nccl==2.20.5`. Thank you for your time.