Skip to content

Commit

Permalink
[jax-triton] Synchronize autotuning stream with a main one
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 609792049
  • Loading branch information
ezhulenev authored and jax authors committed Feb 23, 2024
1 parent bb5997b commit 3a69b80
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 22 deletions.
64 changes: 44 additions & 20 deletions jaxlib/gpu/triton_kernels.cc
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
#include "jaxlib/gpu/triton_kernels.h"

#include <algorithm>
#include <cstdint>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <memory>
#include <string>
#include <string_view>
#include <tuple>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <variant>
#include <vector>

#include "absl/base/optimization.h"
#include "absl/base/thread_annotations.h"
#include "absl/cleanup/cleanup.h"
#include "absl/container/flat_hash_map.h"
#include "absl/log/check.h"
Expand All @@ -26,11 +31,9 @@
#include "jaxlib/gpu/vendor.h"
#include "xla/service/custom_call_status.h"
#include "xla/stream_executor/gpu/asm_compiler.h"
#include "tsl/platform/env.h"

#define GPU_RETURN_IF_ERROR(expr) JAX_RETURN_IF_ERROR(JAX_AS_STATUS(expr))


namespace jax::JAX_GPU_NAMESPACE {
namespace {

Expand Down Expand Up @@ -82,9 +85,9 @@ absl::StatusOr<ModuleImage*> GetModuleImage(std::string kernel_name,
auto it = module_images.find(key);
if (it != module_images.end()) return it->second.get();

#ifdef JAX_GPU_HIP //For HIP/ROCM just read the hsaco file
#ifdef JAX_GPU_HIP // For HIP/ROCM just read the hsaco file
std::string result_blob;
std::string fname{ptx};
std::string fname{ptx};
TF_RETURN_IF_ERROR(
tsl::ReadFileToString(tsl::Env::Default(), fname, &result_blob));
std::vector<uint8_t> module_image(result_blob.begin(), result_blob.end());
Expand Down Expand Up @@ -230,22 +233,22 @@ class ModuleImage {
}

if (shared_optin > kMaxStaticSharedMemBytes) {
#ifdef JAX_GPU_CUDA
GPU_RETURN_IF_ERROR(
#ifdef JAX_GPU_CUDA
GPU_RETURN_IF_ERROR(
gpuFuncSetCacheConfig(function, CU_FUNC_CACHE_PREFER_SHARED));
#endif
#endif
int shared_total;
GPU_RETURN_IF_ERROR(gpuDeviceGetAttribute(
&shared_total,
GPU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device));
int shared_static;
GPU_RETURN_IF_ERROR(gpuFuncGetAttribute(
&shared_static, GPU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, function));
#ifdef JAX_GPU_CUDA
GPU_RETURN_IF_ERROR(cuFuncSetAttribute(
#ifdef JAX_GPU_CUDA
GPU_RETURN_IF_ERROR(cuFuncSetAttribute(
function, GPU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES,
shared_optin - shared_static));
#endif
#endif
}
return function;
}
Expand All @@ -257,7 +260,8 @@ class ModuleImage {

absl::Mutex mutex_;
std::vector<OwnedGPUmodule> modules_ ABSL_GUARDED_BY(mutex_);
absl::flat_hash_map<gpuContext_t, gpuFunction_t> functions_ ABSL_GUARDED_BY(mutex_);
absl::flat_hash_map<gpuContext_t, gpuFunction_t> functions_
ABSL_GUARDED_BY(mutex_);
};

Kernel::Kernel(std::string kernel_name, uint32_t num_warps,
Expand Down Expand Up @@ -536,23 +540,35 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
/*static*/ absl::StatusOr<KernelCall> AutotunedKernelCall::Autotune(
AutotunedKernelCall kernel_call, gpuStream_t stream, void** buffers) {
// Ensure a valid context for driver calls that don't take the stream.
//gpuContext_t context;
//GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
//GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
//absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };

// gpuContext_t context;
// GPU_RETURN_IF_ERROR(gpuStreamGetCtx(stream, &context));
// GPU_RETURN_IF_ERROR(gpuCtxPushCurrent(context));
// absl::Cleanup ctx_restorer = [] { gpuCtxPopCurrent(nullptr); };

// If `stream` is in capture mode we can't run autotuning on it as we don't
// want to capture it into a graph. We create a new stream to do autotuning
// and destroy it when we are done.
gpustreamCaptureStatus_t capture_status;
GPU_RETURN_IF_ERROR(gpuStreamIsCapturing(stream, &capture_status));
bool is_capturing = capture_status == GPU_STREAM_CAPTURE_STATUS_ACTIVE;

gpustreamCaptureMode_t capture_mode = GPU_STREAM_CAPTURE_MODE_RELAXED;
gpuStream_t autotune_stream = stream;

// An event that synchronizes autotuning stream with a main one.
gpuEvent_t autotune_event = nullptr;

if (is_capturing) {

GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
// Need a side stream so as not to interfere with graph capture.
GPU_RETURN_IF_ERROR(gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING));

// Record event after completion of launched kernels on the main stream.
GPU_RETURN_IF_ERROR(gpuEventCreate(&autotune_event, 0));
GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, stream));

// Create a new stream to run autotuning and synchronize it with main sream.
GPU_RETURN_IF_ERROR(
gpuStreamCreate(&autotune_stream, GPU_STREAM_NON_BLOCKING));
GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(autotune_stream, autotune_event));
}

// If an input aliases with an output, it will get overwritten during the
Expand Down Expand Up @@ -627,14 +643,22 @@ jax_triton::TritonAutotunedKernelCall AutotunedKernelCall::ToProto() const {
reinterpret_cast<gpuDevicePtr_t>(buffers[input_idx]),
input_copies[input_idx].data(), size, autotune_stream));
}

// Synchronize stream to ensure copies are complete before the host copy
// is deleted.
GPU_RETURN_IF_ERROR(gpuStreamSynchronize(autotune_stream));

if (is_capturing) {
// Wait on a main stream for completion of autotuning.
GPU_RETURN_IF_ERROR(gpuEventRecord(autotune_event, autotune_stream));
GPU_RETURN_IF_ERROR(gpuStreamWaitEvent(stream, autotune_event));
GPU_RETURN_IF_ERROR(gpuEventDestroy(autotune_event));

// Destroy autotuning stream and recover stream capturing mode.
GPU_RETURN_IF_ERROR(gpuStreamDestroy(autotune_stream));
GPU_RETURN_IF_ERROR(gpuThreadExchangeStreamCaptureMode(&capture_mode));
}

return std::move(kernel_call.configs_[0].kernel_call);
}

Expand Down
5 changes: 3 additions & 2 deletions jaxlib/gpu/vendor.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,7 @@ typedef cusparseDnVecDescr_t gpusparseDnVecDescr_t;
#define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
#define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
#define gpuStreamSynchronize cudaStreamSynchronize
#define gpuStreamWaitEvent cudaStreamWaitEvent
#define gpuSuccess cudaSuccess

namespace jax::JAX_GPU_NAMESPACE {
Expand Down Expand Up @@ -501,9 +502,9 @@ typedef hipsparseDnVecDescr_t gpusparseDnVecDescr_t;
#define GPUSPARSE_SPARSETODENSE_ALG_DEFAULT HIPSPARSE_SPARSETODENSE_ALG_DEFAULT
#define GPUSPARSE_STATUS_SUCCESS HIPSPARSE_STATUS_SUCCESS

#define GPU_STREAM_CAPTURE_STATUS_ACTIVE hipStreamCaptureStatusActive
#define GPU_STREAM_CAPTURE_STATUS_ACTIVE hipStreamCaptureStatusActive
#define GPU_STREAM_CAPTURE_MODE_RELAXED hipStreamCaptureModeRelaxed
#define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking
#define GPU_STREAM_NON_BLOCKING hipStreamNonBlocking

#define gpuGetLastError hipGetLastError
#define gpuGetErrorString hipGetErrorString
Expand Down

0 comments on commit 3a69b80

Please sign in to comment.