Skip to content

Commit

Permalink
Make it easy to profile workers with nsight (vllm-project#3162)
Browse files Browse the repository at this point in the history
Co-authored-by: Roger Wang <[email protected]>
  • Loading branch information
2 people authored and dbogunowicz committed Mar 26, 2024
1 parent 3331d5b commit c0df374
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 2 deletions.
6 changes: 6 additions & 0 deletions benchmarks/benchmark_latency.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def main(args: argparse.Namespace):
enforce_eager=args.enforce_eager,
kv_cache_dtype=args.kv_cache_dtype,
device=args.device,
ray_workers_use_nsight=args.ray_workers_use_nsight,
)

sampling_params = SamplingParams(
Expand Down Expand Up @@ -145,5 +146,10 @@ def run_to_completion(profile_dir: Optional[str] = None):
default="cuda",
choices=["cuda"],
help='device type for vLLM execution, supporting CUDA only currently.')
parser.add_argument(
"--ray-workers-use-nsight",
action='store_true',
help="If specified, use nsight to profile ray workers",
)
args = parser.parse_args()
main(args)
7 changes: 7 additions & 0 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,8 @@ class ParallelConfig:
parallel and large models.
disable_custom_all_reduce: Disable the custom all-reduce kernel and
fall back to NCCL.
ray_workers_use_nsight: Whether to profile Ray workers with nsight, see
https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler.
"""

def __init__(
Expand All @@ -391,6 +393,7 @@ def __init__(
worker_use_ray: bool,
max_parallel_loading_workers: Optional[int] = None,
disable_custom_all_reduce: bool = False,
ray_workers_use_nsight: bool = False,
) -> None:
self.pipeline_parallel_size = pipeline_parallel_size
if is_neuron():
Expand All @@ -404,6 +407,7 @@ def __init__(
self.worker_use_ray = worker_use_ray
self.max_parallel_loading_workers = max_parallel_loading_workers
self.disable_custom_all_reduce = disable_custom_all_reduce
self.ray_workers_use_nsight = ray_workers_use_nsight

self.world_size = pipeline_parallel_size * self.tensor_parallel_size
# Ray worker is not supported for Neuron backend.
Expand All @@ -426,6 +430,9 @@ def _verify_args(self) -> None:
logger.info(
"Disabled the custom all-reduce kernel because it is not "
"supported with pipeline parallelism.")
if self.ray_workers_use_nsight and not self.worker_use_ray:
raise ValueError("Unable to use nsight profiling unless workers "
"run with Ray.")

# FIXME(woosuk): Fix the stability issues and re-enable the custom
# all-reduce kernel.
Expand Down
8 changes: 7 additions & 1 deletion vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class EngineArgs:
lora_dtype = 'auto'
max_cpu_loras: Optional[int] = None
device: str = 'auto'
ray_workers_use_nsight: bool = False

def __post_init__(self):
if self.tokenizer is None:
Expand Down Expand Up @@ -168,6 +169,10 @@ def add_cli_args(
help='load model sequentially in multiple batches, '
'to avoid RAM OOM when using tensor '
'parallel and large models')
parser.add_argument(
'--ray-workers-use-nsight',
action='store_true',
help='If specified, use nsight to profile ray workers')
# KV cache arguments
parser.add_argument('--block-size',
type=int,
Expand Down Expand Up @@ -305,7 +310,8 @@ def create_engine_configs(
self.tensor_parallel_size,
self.worker_use_ray,
self.max_parallel_loading_workers,
self.disable_custom_all_reduce)
self.disable_custom_all_reduce,
self.ray_workers_use_nsight)
scheduler_config = SchedulerConfig(self.max_num_batched_tokens,
self.max_num_seqs,
model_config.max_model_len,
Expand Down
15 changes: 14 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,20 @@ def __init__(
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
if ray_usage != "1":
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
self._init_workers_ray(placement_group)
# Pass additional arguments to initialize the worker
additional_ray_args = {}
if self.parallel_config.ray_workers_use_nsight:
logger.info("Configuring Ray workers to use nsight.")
additional_ray_args = {
"runtime_env": {
"nsight": {
"t": "cuda,cudnn,cublas",
"o": "'worker_process_%p'",
"cuda-graph-trace": "node",
}
}
}
self._init_workers_ray(placement_group, **additional_ray_args)
else:
self._init_workers()

Expand Down

0 comments on commit c0df374

Please sign in to comment.