Skip to content

Commit

Permalink
add memcpy_async
Browse files Browse the repository at this point in the history
  • Loading branch information
yufenglee committed Sep 13, 2023
1 parent 1d217c6 commit d026123
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,9 @@ __global__ void MatMulFloatInt4Kernel(

constexpr int stages_count = 2;
constexpr int elements_per_ite = 256;
__shared__ T a_shared[elements_per_ite * stages_count];
__shared__ alignas(alignof(float4)) T a_shared[elements_per_ite * stages_count];
auto group = cooperative_groups::this_thread_block();
T* a_shared_stages[stages_count] = {a_shared, a_shared + 2 * group.size()};
T* a_shared_stages[stages_count] = {a_shared, a_shared + elements_per_ite};

// Create a synchronization object (cuda::pipeline)
__shared__ nvidia_cuda::pipeline_shared_state<nvidia_cuda::thread_scope::thread_scope_block, stages_count> shared_state;

Check warning on line 112 in onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu#L112

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu:112:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
Expand All @@ -119,11 +119,12 @@ __global__ void MatMulFloatInt4Kernel(
// fetch from global to shared
for (; fetch < k_iter && fetch < (k_step + stages_count); fetch++) {
pipeline.producer_acquire();
nvidia_cuda::memcpy_async(group, a_shared_stages[fetch % 2], a_data + k_id, sizeof(T) * elements_per_ite, pipeline);
nvidia_cuda::memcpy_async(group, a_shared_stages[fetch % 2], a_data + fetch * elements_per_ite, sizeof(T) * elements_per_ite, pipeline);

Check warning on line 122 in onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu#L122

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/contrib_ops/cuda/quantization/matmul_with_quant_weight.cu:122:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
pipeline.producer_commit();
}

pipeline.consumer_wait();
__syncthreads();
T scale = b_scale_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
uint8_t zp = b_zp_vec[warp_id * group_count + (k_id + lane_id * 8) / group_size];
sum += AccumulateEightElements(value, scale, zp, a_shared_stages[k_step % 2] + (lane_id << 3));
Expand Down

0 comments on commit d026123

Please sign in to comment.