diff --git a/jaxlib/gpu/triton_kernels.cc b/jaxlib/gpu/triton_kernels.cc index cfb1d51fdf84..b921578557a0 100644 --- a/jaxlib/gpu/triton_kernels.cc +++ b/jaxlib/gpu/triton_kernels.cc @@ -1,17 +1,22 @@ #include "jaxlib/gpu/triton_kernels.h" #include -#include #include +#include +#include +#include #include #include #include #include #include +#include +#include #include #include #include "absl/base/optimization.h" +#include "absl/base/thread_annotations.h" #include "absl/cleanup/cleanup.h" #include "absl/container/flat_hash_map.h" #include "absl/log/check.h" @@ -26,11 +31,9 @@ #include "jaxlib/gpu/vendor.h" #include "xla/service/custom_call_status.h" #include "xla/stream_executor/gpu/asm_compiler.h" -#include "tsl/platform/env.h" #define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr)) - namespace jax::JAX_GPU_NAMESPACE { namespace { @@ -82,9 +85,9 @@ absl::StatusOr GetModuleImage(std::string kernel_name, auto it = module_images.find(key); if (it != module_images.end()) return it->second.get(); -#ifdef JAX_GPU_HIP //For HIP/ROCM just read the hsaco file +#ifdef JAX_GPU_HIP // For HIP/ROCM just read the hsaco file std::string result_blob; - std::string fname{ptx}; + std::string fname{ptx}; TF_RETURN_IF_ERROR( tsl::ReadFileToString(tsl::Env::Default(), fname, &result_blob)); std::vector module_image(result_blob.begin(), result_blob.end()); @@ -230,10 +233,10 @@ class ModuleImage { } if (shared_optin > kMaxStaticSharedMemBytes) { - #ifdef JAX_GPU_CUDA - GPU_RETURN_IF_ERROR( +#ifdef JAX_GPU_CUDA + GPU_RETURN_IF_ERROR( gpuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED)); - #endif +#endif int shared_total; GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute( &shared_total, @@ -241,11 +244,11 @@ class ModuleImage { int shared_static; GPU_RETURN_IF_ERROR(gpuFuncGetAttribute( &shared_static, GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function)); - #ifdef JAX_GPU_CUDA - GPU_RETURN_IF_ERROR(cuFuncSetAttribute( +#ifdef JAX_GPU_CUDA + 
GPU_RETURN_IF_ERROR(cuFuncSetAttribute( function, GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static)); - #endif +#endif } return function; } @@ -257,7 +260,8 @@ class ModuleImage { absl::Mutex mutex_; std::vector modules_ ABSL_GUARDED_BY(mutex_); - absl::flat_hash_map functions_ ABSL_GUARDED_BY(mutex_); + absl::flat_hash_map functions_ + ABSL_GUARDED_BY(mutex_); }; Kernel::Kernel(std::string kernel_name, uint32_t num_warps, @@ -536,11 +540,14 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { /*static*/ absl::StatusOr AutotunedKernelCall::Autotune( AutotunedKernelCall kernel_call, gpuStream_t stream, void** buffers) { // Ensure a valid context for driver calls that don't take the stream. - //gpuContext_t context; - //GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context)); - //GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); - //absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; - + // gpuContext_t context; + // GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context)); + // GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context)); + // absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); }; + + // If `stream` is in capture mode we can't run autotuning on it as we don't + // want to capture it into a graph. We create a new stream to do autotuning + // and destroy it when we are done. gpustreamCaptureStatus_t capture_status; GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status)); bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE; @@ -548,11 +555,20 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED; gpuStream_t autotune_stream = stream; + // An event that synchronizes autotuning stream with a main one. 
+ gpuEvent_t autotune_event = nullptr; + if (is_capturing) { - GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode)); - // Need a side stream so as not to interfere with graph capture. - GPU_RETURN_IF_ERROR(gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING)); + + // Record event after completion of launched kernels on the main stream. + GPU_RETURN_IF_ERROR(gpuEventCreate(&autotune_event, 0)); + GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, stream)); + + // Create a new stream to run autotuning and synchronize it with main stream. + GPU_RETURN_IF_ERROR( + gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING)); + GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(autotune_stream, autotune_event)); } // If an input aliases with an output, it will get overwritten during the @@ -627,14 +643,22 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const { reinterpret_cast(buffers[input_idx]), input_copies[input_idx].data(), size, autotune_stream)); } + // Synchronize stream to ensure copies are complete before the host copy // is deleted. GPU_RETURN_IF_ERROR(gpuStreamSynchronize(autotune_stream)); if (is_capturing) { + // Wait on a main stream for completion of autotuning. + GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, autotune_stream)); + GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(stream, autotune_event)); + GPU_RETURN_IF_ERROR(gpuEventDestroy(autotune_event)); + + // Destroy autotuning stream and recover stream capturing mode. 
GPU_RETURN_IF_ERROR(gpuStreamDestroy(autotune_stream)); GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode)); } + return std::move(kernel_call.configs_[0].kernel_call); } diff --git a/jaxlib/gpu/vendor.h b/jaxlib/gpu/vendor.h index 880968380039..dcbb95e8b360 100644 --- a/jaxlib/gpu/vendor.h +++ b/jaxlib/gpu/vendor.h @@ -305,6 +305,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t; #define gpuMemcpyHostToDevice cudaMemcpyHostToDevice #define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost #define gpuStreamSynchronize cudaStreamSynchronize +#define gpuStreamWaitEvent cudaStreamWaitEvent #define gpuSuccess cudaSuccess namespace jax::JAX_GPU_NAMESPACE { @@ -501,9 +502,9 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t; #define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT #define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS -#define GPU_STREAM_CAPTURE_STATUS_ACTIVE hipStreamCaptureStatusActive +#define GPU_STREAM_CAPTURE_STATUS_ACTIVE hipStreamCaptureStatusActive #define GPU_STREAM_CAPTURE_MODE_RELAXED hipStreamCaptureModeRelaxed -#define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking +#define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking #define gpuGetLastError hipGetLastError #define gpuGetErrorString hipGetErrorString