fix(export): update API for disabling device reassignment in TRTLLM for Aligner (#10863)

* fix(export): update API for disabling device reassignment in TRTLLM for Aligner

[feat] Upgrade nemo-export path for aligner to TRTLLM-v12 and use python runtime

Signed-off-by: Terry Kong <[email protected]>

fix: forgot to always set _disable_torch_cuda_device_set

Signed-off-by: Terry Kong <[email protected]>

Apply isort and black reformatting

Signed-off-by: terrykong <[email protected]>

invert torch device set

Signed-off-by: Terry Kong <[email protected]>

* remove comment

Signed-off-by: Terry Kong <[email protected]>

---------

Signed-off-by: Terry Kong <[email protected]>
terrykong authored Nov 12, 2024
1 parent 24e2871 commit 085e957
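
The python-runtime path enabled here requires the env var DISABLE_TORCH_DEVICE_SET=1. NeMo imports the flag from tensorrt_llm.runtime.generation as a module-level constant, so it presumably has to be set before tensorrt_llm is first imported. A minimal sketch of a conforming entry point (the engine path and argument values are illustrative, not from this commit):

    import os

    # Must be set before the first import of tensorrt_llm, which reads the
    # env var into the module-level constant DISABLE_TORCH_DEVICE_SET.
    os.environ["DISABLE_TORCH_DEVICE_SET"] = "1"

    from nemo.export.trt_llm.tensorrt_llm_run import load_distributed

    load_distributed(engine_dir="/path/to/trtllm_engine", model_parallel_rank=0, gpus_per_node=8)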
Showing 1 changed file with 17 additions and 3 deletions.
nemo/export/trt_llm/tensorrt_llm_run.py
@@ -32,17 +32,23 @@
 from tensorrt_llm.lora_manager import LoraManager
 from tensorrt_llm.mapping import Mapping
 from tensorrt_llm.quantization import QuantMode
-from tensorrt_llm.runtime import GenerationSession, ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig
+from tensorrt_llm.runtime import ModelConfig, ModelRunner, ModelRunnerCpp, SamplingConfig
 from transformers import PreTrainedTokenizer

 LOGGER = logging.getLogger("NeMo")

 use_trtllm_bindings = True
 try:
-    from tensorrt_llm.bindings import GptJsonConfig, KvCacheConfig, WorldConfig
+    from tensorrt_llm.bindings import GptJsonConfig
 except Exception as e:
     use_trtllm_bindings = False

+TRTLLM_SUPPORTS_DEVICE_DISABLE = True
+try:
+    from tensorrt_llm.runtime.generation import DISABLE_TORCH_DEVICE_SET
+except (ImportError, ModuleNotFoundError):
+    TRTLLM_SUPPORTS_DEVICE_DISABLE = False
+

 @dataclass
 class TensorrtLLMHostContext:
@@ -494,12 +500,20 @@ def load_distributed(engine_dir, model_parallel_rank, gpus_per_node):
         json_config_str = f.read()

     engine = Engine.from_buffer(engine_buffer=engine_data, json_config_str=json_config_str, rank=model_parallel_rank)
+
+    if not TRTLLM_SUPPORTS_DEVICE_DISABLE:
+        raise RuntimeError(
+            "TensorRT-LLM does not support torch device disabling. Please upgrade TensorRT-LLM to make use of this feature."
+        )
+    elif not DISABLE_TORCH_DEVICE_SET:
+        raise RuntimeError(
+            "To use TensorRT-LLM's python ModelRunner API in load_distributed(...) you must set the env var DISABLE_TORCH_DEVICE_SET=1"
+        )
     decoder = ModelRunner.from_engine(
         engine=engine,
         # We want the engine to have the mp_rank, but the python runtime to not reassign the device of the current process
         # So we will set it to the current device
         rank=torch.cuda.current_device(),
-        _disable_torch_cuda_device_set=True,
     )

     tensorrt_llm_worker_context.decoder = decoder
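Taken together, the new guard shifts responsibility to the caller: each process pins its own CUDA device and exports the flag, while the engine is still loaded under its model-parallel rank and ModelRunner keeps the already-selected device instead of reassigning it. A sketch of one rank's setup under those assumptions (the torchrun-style LOCAL_RANK and all argument values are illustrative):

    import os

    # Export the flag before tensorrt_llm is imported (see the guard above).
    os.environ["DISABLE_TORCH_DEVICE_SET"] = "1"

    import torch

    from nemo.export.trt_llm.tensorrt_llm_run import load_distributed

    local_rank = int(os.environ.get("LOCAL_RANK", "0"))  # set by torchrun
    torch.cuda.set_device(local_rank)  # the device the runtime must not reassign

    # The engine is loaded with this process's model-parallel rank; the python
    # runtime leaves the device chosen above untouched.
    load_distributed(engine_dir="/path/to/trtllm_engine", model_parallel_rank=local_rank, gpus_per_node=8)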
