Skip to content

Commit

Permalink
[Core] Move function tracing setup to util function (#4352)
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored Apr 25, 2024
1 parent 15e7c67 commit efffb63
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
21 changes: 20 additions & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import asyncio
import datetime
import enum
import gc
import glob
import os
import socket
import subprocess
import tempfile
import threading
import uuid
import warnings
from collections import defaultdict
Expand All @@ -18,7 +21,7 @@
import torch
from packaging.version import Version, parse

from vllm.logger import init_logger
from vllm.logger import enable_trace_function_call, init_logger

T = TypeVar("T")
logger = init_logger(__name__)
Expand Down Expand Up @@ -607,3 +610,19 @@ def find_nccl_library():
raise ValueError("NCCL only supports CUDA and ROCm backends.")
logger.info(f"Found nccl from library {so_file}")
return so_file


def enable_trace_function_call_for_thread() -> None:
"""Set up function tracing for the current thread,
if enabled via the VLLM_TRACE_FUNCTION environment variable
"""

if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
18 changes: 4 additions & 14 deletions vllm/worker/worker_base.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
import datetime
import importlib
import os
import tempfile
import threading
from abc import ABC, abstractmethod
from typing import Dict, List, Set, Tuple

from vllm.logger import enable_trace_function_call, init_logger
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
from vllm.utils import get_vllm_instance_id, update_environment_variables
from vllm.utils import (enable_trace_function_call_for_thread,
update_environment_variables)

logger = init_logger(__name__)

Expand Down Expand Up @@ -128,15 +126,7 @@ def init_worker(self, *args, **kwargs):
function tracing if required.
Arguments are passed to the worker class constructor.
"""
if int(os.getenv("VLLM_TRACE_FUNCTION", "0")):
tmp_dir = tempfile.gettempdir()
filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}"
f"_thread_{threading.get_ident()}_"
f"at_{datetime.datetime.now()}.log").replace(" ", "_")
log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(),
filename)
os.makedirs(os.path.dirname(log_path), exist_ok=True)
enable_trace_function_call(log_path)
enable_trace_function_call_for_thread()

mod = importlib.import_module(self.worker_module_name)
worker_class = getattr(mod, self.worker_class_name)
Expand Down

0 comments on commit efffb63

Please sign in to comment.