diff --git a/python/src/triton.cc b/python/src/triton.cc
index e331254d3769..c9ef967fdd3a 100644
--- a/python/src/triton.cc
+++ b/python/src/triton.cc
@@ -1583,54 +1583,64 @@ void init_triton_translation(py::module &m) {
   m.def("compile_ptx_to_cubin",
         [](const std::string &ptxCode, const std::string &ptxasPath,
            int capability) -> py::object {
-          py::gil_scoped_release allow_threads;
+          // Raw cubin contents. Declared outside the GIL-released scope so it
+          // outlives that scope and can be wrapped in py::bytes afterwards.
+          std::string cubin;
+          {
+            // Drop the GIL for the file I/O and the blocking ptxas call.
+            py::gil_scoped_release allow_threads;
 
-          // compile ptx with ptxas
-          llvm::SmallString<64> fsrc;
-          llvm::SmallString<64> flog;
-          llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
-          llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
-          std::string fbin = std::string(fsrc) + ".o";
-          llvm::FileRemover logRemover(flog);
-          llvm::FileRemover binRemover(fbin);
-          const char *_fsrc = fsrc.c_str();
-          const char *_flog = flog.c_str();
-          const char *_fbin = fbin.c_str();
-          std::ofstream ofs(_fsrc);
-          ofs << ptxCode << std::endl;
-          ofs.close();
-          std::string cmd;
-          int err;
-          cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
-                (capability == 90 ? "a " : " ") + _fsrc + " -o " + _fsrc +
-                ".o 2> " + _flog;
+            // compile ptx with ptxas
+            llvm::SmallString<64> fsrc;
+            llvm::SmallString<64> flog;
+            llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
+            llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
+            std::string fbin = std::string(fsrc) + ".o";
+            llvm::FileRemover logRemover(flog);
+            llvm::FileRemover binRemover(fbin);
+            const char *_fsrc = fsrc.c_str();
+            const char *_flog = flog.c_str();
+            const char *_fbin = fbin.c_str();
+            std::ofstream ofs(_fsrc);
+            ofs << ptxCode << std::endl;
+            ofs.close();
+            std::string cmd;
+            int err;
+            cmd = ptxasPath + " -v --gpu-name=sm_" + std::to_string(capability) +
+                  (capability == 90 ? "a " : " ") + _fsrc + " -o " + _fsrc +
+                  ".o 2> " + _flog;
 
-          err = system(cmd.c_str());
-          if (err != 0) {
-            err >>= 8;
-            std::ifstream _log(_flog);
-            std::string log(std::istreambuf_iterator<char>(_log), {});
-            if (err == 255) {
-              throw std::runtime_error("Internal Triton PTX codegen error: \n" +
-                                       log);
-            } else if (err == 128 + SIGSEGV) {
-              throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
-                                       "` to confirm that this is a "
-                                       "bug in `ptxas`\n" +
-                                       log);
+            err = system(cmd.c_str());
+            if (err != 0) {
+              err >>= 8;
+              std::ifstream _log(_flog);
+              std::string log(std::istreambuf_iterator<char>(_log), {});
+              if (err == 255) {
+                throw std::runtime_error("Internal Triton PTX codegen error: \n" +
+                                         log);
+              } else if (err == 128 + SIGSEGV) {
+                throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
+                                         "` to confirm that this is a "
+                                         "bug in `ptxas`\n" +
+                                         log);
+              } else {
+                throw std::runtime_error("`ptxas` failed with error code " +
+                                         std::to_string(err) + ": \n" + log);
+              }
             } else {
-              throw std::runtime_error("`ptxas` failed with error code " +
-                                       std::to_string(err) + ": \n" + log);
+              // Removed only on success: on failure the PTX source stays on
+              // disk, and the SIGSEGV message above names it for a re-run.
+              llvm::FileRemover srcRemover(fsrc);
+              std::ifstream _cubin(_fbin, std::ios::binary);
+              cubin = std::string(std::istreambuf_iterator<char>(_cubin), {});
+              _cubin.close();
+              // Do not return here, exit the gil scope and return below
             }
-            return {};
-          } else {
-            llvm::FileRemover srcRemover(fsrc);
-            std::ifstream _cubin(_fbin, std::ios::binary);
-            std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
-            _cubin.close();
-            py::bytes bytes(cubin);
-            return std::move(bytes);
           }
+          // allow_threads went out of scope above, so the GIL is held again
+          // here, which is required before creating Python objects.
+          py::bytes bytes(cubin);
+          return std::move(bytes);
         });
 
   m.def("add_external_libs",