Skip to content

Commit

Permalink
Fix all-reduce memory usage (vllm-project#2151)
Browse files Browse the repository at this point in the history
  • Loading branch information
WoosukKwon authored and jimpang committed Dec 18, 2023
1 parent 0a147e3 commit 11f7905
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions vllm/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ def __init__(
self.lora_manager = None

def init_model(self, cupy_port: Optional[int] = None):
# torch.distributed.all_reduce does not free the input tensor until
# the synchronization point. This causes the memory usage to grow
# as the number of all_reduce calls increases. This env var disables
# this behavior.
# Related issue:
# https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"

# This env var set by Ray causes exceptions with graph building.
os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None)
# Env vars will be set by Ray.
Expand Down

0 comments on commit 11f7905

Please sign in to comment.