Your current environment
The output of `python collect_env.py`
PyTorch version: 2.3.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: 14.0.0-1ubuntu1.1
CMake version: version 3.30.0
Libc version: glibc-2.35
Python version: 3.10.12 (main, Nov 20 2023, 15:14:05) [GCC 11.4.0] (64-bit runtime)
Python platform: Linux-5.4.0-48-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.3.107
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration:
GPU 0: NVIDIA A800-SXM4-80GB
GPU 1: NVIDIA A800-SXM4-80GB
GPU 2: NVIDIA A800-SXM4-80GB
GPU 3: NVIDIA A800-SXM4-80GB
GPU 4: NVIDIA A800-SXM4-80GB
GPU 5: NVIDIA A800-SXM4-80GB
GPU 6: NVIDIA A800-SXM4-80GB
GPU 7: NVIDIA A800-SXM4-80GB
Nvidia driver version: 535.129.03
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.0.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.0.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 46 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 128
On-line CPU(s) list: 0-127
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8358 CPU @ 2.60GHz
CPU family: 6
Model: 106
Thread(s) per core: 2
Core(s) per socket: 32
Socket(s): 2
Stepping: 6
Frequency boost: enabled
CPU max MHz: 3400.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 invpcid_single ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq rdpid md_clear pconfig flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 3 MiB (64 instances)
L1i cache: 2 MiB (64 instances)
L2 cache: 80 MiB (64 instances)
L3 cache: 96 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-31,64-95
NUMA node1 CPU(s): 32-63,96-127
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-nccl-cu12==2.20.5
[pip3] torch==2.3.0
[pip3] torchvision==0.18.0
[pip3] transformers==4.42.3
[pip3] triton==2.3.0
[pip3] triton-nightly==3.0.0.post20240626041721
[conda] Could not collect
ROCM Version: Could not collect
Neuron SDK Version: N/A
vLLM Version: 0.5.1
vLLM Build Flags:
CUDA Archs: Not Set; ROCm: Disabled; Neuron: Disabled
GPU Topology:
        GPU0  GPU1  GPU2  GPU3  GPU4  GPU5  GPU6  GPU7  NIC0  NIC1  NIC2  NIC3  NIC4  CPU Affinity  NUMA Affinity  GPU NUMA ID
GPU0    X     NV8   NV8   NV8   NV8   NV8   NV8   NV8   PXB   SYS   SYS   SYS   SYS   0-31,64-95    0              N/A
GPU1    NV8   X     NV8   NV8   NV8   NV8   NV8   NV8   PXB   SYS   SYS   SYS   SYS   0-31,64-95    0              N/A
GPU2    NV8   NV8   X     NV8   NV8   NV8   NV8   NV8   SYS   PXB   SYS   SYS   SYS   0-31,64-95    0              N/A
GPU3    NV8   NV8   NV8   X     NV8   NV8   NV8   NV8   SYS   PXB   SYS   SYS   SYS   0-31,64-95    0              N/A
GPU4    NV8   NV8   NV8   NV8   X     NV8   NV8   NV8   SYS   SYS   PXB   SYS   SYS   32-63,96-127  1              N/A
GPU5    NV8   NV8   NV8   NV8   NV8   X     NV8   NV8   SYS   SYS   PXB   SYS   SYS   32-63,96-127  1              N/A
GPU6    NV8   NV8   NV8   NV8   NV8   NV8   X     NV8   SYS   SYS   SYS   PXB   SYS   32-63,96-127  1              N/A
GPU7    NV8   NV8   NV8   NV8   NV8   NV8   NV8   X     SYS   SYS   SYS   PXB   SYS   32-63,96-127  1              N/A
NIC0    PXB   PXB   SYS   SYS   SYS   SYS   SYS   SYS   X     SYS   SYS   SYS   SYS
NIC1    SYS   SYS   PXB   PXB   SYS   SYS   SYS   SYS   SYS   X     SYS   SYS   SYS
NIC2    SYS   SYS   SYS   SYS   PXB   PXB   SYS   SYS   SYS   SYS   X     SYS   SYS
NIC3    SYS   SYS   SYS   SYS   SYS   SYS   PXB   PXB   SYS   SYS   SYS   X     SYS
NIC4    SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   SYS   X
Legend:
X = Self
SYS = Connection traversing PCIe as well as the SMP interconnect between NUMA nodes (e.g., QPI/UPI)
NODE = Connection traversing PCIe as well as the interconnect between PCIe Host Bridges within a NUMA node
PHB = Connection traversing PCIe as well as a PCIe Host Bridge (typically the CPU)
PXB = Connection traversing multiple PCIe bridges (without traversing the PCIe Host Bridge)
PIX = Connection traversing at most a single PCIe bridge
NV# = Connection traversing a bonded set of # NVLinks
NIC Legend:
NIC0: mlx5_0
NIC1: mlx5_1
NIC2: mlx5_4
NIC3: mlx5_5
NIC4: mlx5_bond_0
🐛 Describe the bug
My code (some unimportant parts were removed):
import asyncio
import json
import os
import threading
from typing import AsyncGenerator

import numpy as np
import triton_python_backend_utils as pb_utils
from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.utils import random_uuid
import sys

sys.path.append("/model")

_VLLM_ENGINE_ARGS_FILENAME = "model.json"


class TritonPythonModel:
    @staticmethod
    def auto_complete_config(auto_complete_model_config):
        '''init params'''
        return auto_complete_model_config

    def initialize(self, args):
        self.logger = pb_utils.Logger
        self.model_config = json.loads(args["model_config"])

        # assert are in decoupled mode. Currently, Triton needs to use
        # decoupled policy for asynchronously forwarding requests to
        # vLLM engine.
        self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
            self.model_config
        )
        assert (
            self.using_decoupled
        ), "vLLM Triton backend must be configured to use decoupled model transaction policy"

        engine_args_filepath = os.path.join(
            pb_utils.get_model_dir(), _VLLM_ENGINE_ARGS_FILENAME
        )
        with open(engine_args_filepath) as file:
            vllm_engine_config = json.load(file)

        # Create an AsyncLLMEngine from the config from JSON
        self.logger.log_info(f"engine_args: {vllm_engine_config}")
        self.llm_engine = AsyncLLMEngine.from_engine_args(
            AsyncEngineArgs(**vllm_engine_config)
        )

        output_config = pb_utils.get_output_config_by_name(
            self.model_config, "text_output"
        )
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

        # Counter to keep track of ongoing request counts
        self.ongoing_request_count = 0

        # Starting asyncio event loop to process the received requests asynchronously.
        self._loop = asyncio.get_event_loop()
        self._loop_thread = threading.Thread(
            target=self.engine_loop, args=(self._loop,)
        )
        self._shutdown_event = asyncio.Event()
        self._loop_thread.start()

    def create_task(self, coro):
        """
        Creates a task on the engine's event loop which is running on a separate thread.
        """
        assert (
            self._shutdown_event.is_set() is False
        ), "Cannot create tasks after shutdown has been requested"

        return asyncio.run_coroutine_threadsafe(coro, self._loop)

    def engine_loop(self, loop):
        """
        Runs the engine's event loop on a separate thread.
        """
        asyncio.set_event_loop(loop)
        self._loop.run_until_complete(self.await_shutdown())

    async def await_shutdown(self):
        """
        Primary coroutine running on the engine event loop. This coroutine is responsible for
        keeping the engine alive until a shutdown is requested.
        """
        # first await the shutdown signal
        while self._shutdown_event.is_set() is False:
            await asyncio.sleep(5)

        # Wait for the ongoing_requests
        while self.ongoing_request_count > 0:
            self.logger.log_info(
                "[vllm] Awaiting remaining {} requests".format(
                    self.ongoing_request_count
                )
            )
            await asyncio.sleep(5)

        for task in asyncio.all_tasks(loop=self._loop):
            if task is not asyncio.current_task():
                task.cancel()

        self.logger.log_info("[vllm] Shutdown complete")

    def get_sampling_params_dict(self, params_json):
        """
        This function parses the dictionary values into their expected format.
        """

    def create_response(self, vllm_output, text_len=0):
        """
        Parses the output from the vLLM engine into Triton response.
        """
        prompt = vllm_output.prompt
        text_outputs = []
        for output in vllm_output.outputs:
            output_text = prompt + output.text
            text_outputs.append(output_text[text_len:].encode("utf-8"))
            text_len = len(output_text)
        triton_output_tensor = pb_utils.Tensor(
            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
        )
        model_version_tensor = pb_utils.Tensor(
            "wps_model_version", np.asarray([self.model_version.encode("utf-8")], dtype=self.model_version_dtype)
        )
        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor, model_version_tensor]), text_len

    async def generate(self, request):
        """
        Forwards single request to LLM engine and returns responses.
        """
        response_sender = request.get_response_sender()
        self.ongoing_request_count += 1
        try:
            prompt = pb_utils.get_input_tensor_by_name(
                request, "PROMPTS"
            ).as_numpy()[0]
            if isinstance(prompt, bytes):
                prompt = prompt.decode("utf-8")

            # Request parameters are not yet supported via
            # BLS. Provide an optional mechanism to receive serialized
            # parameters as an input tensor until support is added
            parameters_input_tensor = pb_utils.get_input_tensor_by_name(
                request, "sampling_parameters"
            )
            if parameters_input_tensor:
                parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
            else:
                parameters = request.parameters()

            sampling_params_dict = self.get_sampling_params_dict(parameters)
            sampling_params = SamplingParams(**sampling_params_dict)

            last_output = None
            stream_start_index = 0
            async for output in self.llm_engine.generate(
                prompt, sampling_params
            ):
                if response_sender.is_cancelled():
                    self.logger.log_info("[vllm] Cancelling the request")
                    await self.llm_engine.abort(request_id)
                    self.logger.log_info("[vllm] Successfully cancelled the request")
                    break
                if stream:
                    response, output_len = self.create_response(output, stream_start_index)
                    response_sender.send(response)
                    stream_start_index = output_len
                else:
                    last_output = output

            if not stream:
                response_sender.send(self.create_response(last_output)[0])

        except Exception as e:
            self.logger.log_info(f"[vllm] Error generating stream: {e}")
            error = pb_utils.TritonError(f"Error generating stream: {e}")
            triton_output_tensor = pb_utils.Tensor(
                "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
            )
            response = pb_utils.InferenceResponse(
                output_tensors=[triton_output_tensor], error=error
            )
            response_sender.send(response)
            raise e
        finally:
            response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL)
            self.ongoing_request_count -= 1

    def execute(self, requests):
        """
        Triton core issues requests to the backend via this method.

        When this method returns, new requests can be issued to the backend. Blocking
        this function would prevent the backend from pulling additional requests from
        Triton into the vLLM engine. This can be done if the kv cache within vLLM engine
        is too loaded.
        We are pushing all the requests on vllm and let it handle the full traffic.
        """
        for request in requests:
            self.create_task(self.generate(request))
        return None

    def finalize(self):
        """
        Triton virtual method; called when the model is unloaded.
        """
        self.logger.log_info("[vllm] Issuing finalize to vllm backend")
        self._shutdown_event.set()
        if self._loop_thread is not None:
            self._loop_thread.join()
            self._loop_thread = None
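For anyone trying to reproduce this outside Triton, the threading pattern used above (an event loop owned by a background thread, with work submitted through `asyncio.run_coroutine_threadsafe`) can be sketched in isolation. This is only a minimal sketch, not the Triton backend's implementation: `LoopThread`, `submit`, `stop`, and `double` are hypothetical names, and it assumes the returned future should have its exception retrieved so failures are not silently dropped.

```python
import asyncio
import threading


class LoopThread:
    """Hypothetical helper: owns an event loop that runs on a background thread."""

    def __init__(self):
        # new_event_loop() avoids depending on the creating thread's current loop.
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def _run(self):
        asyncio.set_event_loop(self._loop)
        self._loop.run_forever()

    def submit(self, coro):
        """Schedule a coroutine from any thread; returns a concurrent.futures.Future."""
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        # Retrieve the exception so a failed task is reported rather than ignored.
        future.add_done_callback(self._report)
        return future

    @staticmethod
    def _report(future):
        if not future.cancelled() and future.exception() is not None:
            print(f"task failed: {future.exception()!r}")

    def stop(self):
        # loop.stop() is not thread-safe, so it is handed to the loop thread.
        self._loop.call_soon_threadsafe(self._loop.stop)
        self._thread.join()
        self._loop.close()


async def double(x):
    await asyncio.sleep(0.01)
    return 2 * x


runner = LoopThread()
result = runner.submit(double(21)).result(timeout=5)
runner.stop()
print(result)
```

The main differences from the code in this report are that the loop is created explicitly with `asyncio.new_event_loop()`, shutdown is requested through `call_soon_threadsafe` instead of touching loop-bound objects from another thread, and the future returned by `run_coroutine_threadsafe` is kept and inspected.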
The following error occurs occasionally during the call:
ERROR:asyncio:Task exception was never retrieved
future: <Task finished name='Task-424643' coro=<<async_generator_athrow without __name__>()> exception=RuntimeError('aclose(): asynchronous generator is already running')>
RuntimeError: aclose(): asynchronous generator is already running
ERROR:asyncio:Task was destroyed but it is pending!
task: <Task pending name='Task-424498' coro=<TritonPythonModel.generate() done, defined at /triton_deploy/model_common.py:278> wait_for=<Future pending cb=[Task.task_wakeup()]> cb=[_chain_future.<locals>._call_set_state() at /usr/lib/python3.10/asyncio/futures.py:392]>
ERROR:asyncio:Task exception was never retrieved
future: <Task finished name='Task-424644' coro=<<async_generator_athrow without __name__>()> exception=RuntimeError('aclose(): asynchronous generator is already running')>
RuntimeError: aclose(): asynchronous generator is already running
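The `aclose(): asynchronous generator is already running` message can be produced with plain asyncio, independent of vLLM or Triton. The sketch below is only an illustration of one way the message arises (closing an async generator while a `__anext__()` call on it is still in flight); it is not a claim about the exact code path in this report. `ticker` and `main` are hypothetical names. As far as I can tell, asyncio schedules `aclose()` in a similar way when an async generator is garbage-collected before it has finished, which would match the `async_generator_athrow` tasks in the log.

```python
import asyncio


async def ticker():
    # Hypothetical async generator that spends most of its time awaiting.
    while True:
        await asyncio.sleep(0.05)
        yield "tick"


async def main():
    gen = ticker()
    # Start consuming: the generator is now suspended inside asyncio.sleep().
    consumer = asyncio.ensure_future(gen.__anext__())
    await asyncio.sleep(0)

    message = None
    try:
        # Close the generator while the __anext__() call above is still in flight.
        await gen.aclose()
    except RuntimeError as exc:
        message = str(exc)

    # Clean up the in-flight consumer before returning.
    consumer.cancel()
    try:
        await consumer
    except asyncio.CancelledError:
        pass
    return message


message = asyncio.run(main())
print(message)
```

If this is the mechanism at play, keeping a reference to the generator returned by `llm_engine.generate(...)` and letting the `async for` loop run to completion (or closing it from the same task) would be worth checking.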
After that, the engine background task fails with the following error:
ERROR 07-10 13:59:59 async_llm_engine.py:53] Engine background task failed
ERROR 07-10 13:59:59 async_llm_engine.py:53] Traceback (most recent call last):
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 43, in _log_task_completion
ERROR 07-10 13:59:59 async_llm_engine.py:53] return_value = task.result()
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 595, in run_engine_loop
ERROR 07-10 13:59:59 async_llm_engine.py:53] result = task.result()
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 540, in engine_step
ERROR 07-10 13:59:59 async_llm_engine.py:53] request_outputs = await self.engine.step_async(virtual_engine)
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 251, in step_async
ERROR 07-10 13:59:59 async_llm_engine.py:53] self.do_log_stats(scheduler_outputs, output)
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 903, in do_log_stats
ERROR 07-10 13:59:59 async_llm_engine.py:53] logger.log(self._get_stats(scheduler_outputs, model_output))
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/vllm/engine/metrics.py", line 429, in log
ERROR 07-10 13:59:59 async_llm_engine.py:53] self._log_prometheus(stats)
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/vllm/engine/metrics.py", line 386, in _log_prometheus
ERROR 07-10 13:59:59 async_llm_engine.py:53] self._log_counter(self.metrics.counter_generation_tokens,
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/vllm/engine/metrics.py", line 354, in _log_counter
ERROR 07-10 13:59:59 async_llm_engine.py:53] counter.labels(**self.labels).inc(data)
ERROR 07-10 13:59:59 async_llm_engine.py:53] File "/usr/local/lib/python3.10/dist-packages/prometheus_client/metrics.py", line 313, in inc
ERROR 07-10 13:59:59 async_llm_engine.py:53] raise ValueError('Counters can only be incremented by non-negative amounts.')
ERROR 07-10 13:59:59 async_llm_engine.py:53] ValueError: Counters can only be incremented by non-negative amounts.
ERROR:asyncio:Exception in callback _log_task_completion(error_callback=<bound method...7f351ecc6170>>)(<Task finishe...ve amounts.')>) at /usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py:33
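The final `ValueError` comes from a counter being asked to increment by a negative amount, which suggests the generation-token count computed for that step was negative. The sketch below only illustrates the failure mode and one possible guard. `Counter` here is a hypothetical stand-in written for this example, not `prometheus_client.Counter`, and `safe_inc` is an assumed helper, not part of vLLM.

```python
class Counter:
    """Hypothetical stand-in that mimics the non-negative check seen in the traceback."""

    def __init__(self):
        self.value = 0.0

    def inc(self, amount=1):
        if amount < 0:
            raise ValueError("Counters can only be incremented by non-negative amounts.")
        self.value += amount


def safe_inc(counter, amount):
    """Increment only when the amount is non-negative; report whether it was applied."""
    if amount < 0:
        return False
    counter.inc(amount)
    return True


tokens = Counter()
applied = [safe_inc(tokens, n) for n in (5, -3, 2)]
print(applied, tokens.value)
```

Skipping negative samples like this only keeps the stats logger from taking down the engine loop; it would not explain why the token count went negative in the first place, which is presumably related to the interrupted requests above.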
This issue has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this issue should remain open. Thank you!