From 131c4497754664e0923469ffffbdca009579f2c5 Mon Sep 17 00:00:00 2001 From: Someone Serge Date: Wed, 25 Oct 2023 17:08:23 +0300 Subject: [PATCH] cuda executables: make optional --- python/setup.py | 26 +++++++++++++++++++++----- python/triton/common/backend.py | 10 +++++++--- 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/python/setup.py b/python/setup.py index e07aa7965112..4bec219f1252 100644 --- a/python/setup.py +++ b/python/setup.py @@ -124,7 +124,9 @@ def get_thirdparty_packages(triton_cache_path): # ---- package data --- -def download_and_copy(src_path, version, url_func): +def download_and_copy(src_path, variable, version, url_func): + if variable in os.environ: + return base_dir = os.path.dirname(__file__) arch = platform.machine() if arch == "x86_64": @@ -150,7 +152,6 @@ def download_and_copy(src_path, version, url_func): src_path = os.path.join(temp_dir, src_path) os.makedirs(os.path.split(dst_path)[0], exist_ok=True) shutil.copy(src_path, dst_path) - return dst_suffix # ---- cmake extension ---- @@ -298,9 +299,24 @@ def build_extension(self, ext): subprocess.check_call(["cmake", "--build", ".", "--target", "mlir-doc"], cwd=cmake_dir) -download_and_copy(src_path='bin/ptxas', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2") -download_and_copy(src_path='bin/cuobjdump', version='12.1.111', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2") -download_and_copy(src_path='bin/nvdisasm', version='12.1.105', url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2") +download_and_copy( + src_path="bin/ptxas", + variable="TRITON_PTXAS_PATH", + version="12.1.105", + url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvcc-{version}-0.tar.bz2", +) +download_and_copy( + src_path="bin/cuobjdump", + variable="TRITON_CUOBJDUMP_PATH", + version="12.1.111", + url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-cuobjdump-{version}-0.tar.bz2", +) +download_and_copy( + src_path="bin/nvdisasm", + variable="TRITON_NVDISASM_PATH", + version="12.1.105", + url_func=lambda arch, version: f"https://conda.anaconda.org/nvidia/label/cuda-12.1.1/linux-{arch}/cuda-nvdisasm-{version}-0.tar.bz2", +) setup( name="triton", diff --git a/python/triton/common/backend.py b/python/triton/common/backend.py index aaf35334bef3..1aa8c9fe481d 100644 --- a/python/triton/common/backend.py +++ b/python/triton/common/backend.py @@ -108,7 +108,7 @@ def get_backend(device_type: str): def _path_to_binary(binary: str): base_dir = os.path.join(os.path.dirname(__file__), os.pardir) paths = [ - os.environ.get("TRITON_PTXAS_PATH", ""), + os.environ.get(f"TRITON_{binary.upper()}_PATH", ""), os.path.join(base_dir, "third_party", "cuda", "bin", binary) ] @@ -174,6 +174,10 @@ def get_cuda_version_key(): global _cached_cuda_version_key if _cached_cuda_version_key is None: key = compute_core_version_key() - ptxas = path_to_ptxas()[0] - _cached_cuda_version_key = key + '-' + hashlib.sha1(subprocess.check_output([ptxas, "--version"])).hexdigest() + try: + ptxas = path_to_ptxas()[0] + ptxas_version = subprocess.check_output([ptxas, "--version"]) + except RuntimeError: + ptxas_version = b"NO_PTXAS" + _cached_cuda_version_key = key + '-' + hashlib.sha1(ptxas_version).hexdigest() return _cached_cuda_version_key