Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix performance bugs in scalar reductions #509

Merged

Conversation

magnatelee
Copy link
Contributor

No description provided.

* Use unsigned 64-bit integers instead of signed integers wherever
  possible; CUDA hasn't added an atomic intrinsic for the latter yet.

* Move reduction buffers from zero-copy memory to framebuffer. This
  makes the slow atomic update code path in reduction operators
  run much more efficiently.
@magnatelee
Copy link
Contributor Author

this fixes #506 (cc @rohany)

@manopapad
Copy link
Contributor

Do you you want to also replace the use of DeferredReduction in binary_red.cu?

@magnatelee
Copy link
Contributor Author

Do you you want to also replace the use of DeferredReduction in binary_red.cu?

b204473

src/cunumeric/cuda_help.h Outdated Show resolved Hide resolved
CHECK_CUDA(cudaMemcpyAsync(ptr_, &identity, sizeof(LHS), cudaMemcpyHostToDevice, stream));
}

__device__ void operator<<=(const RHS& value) const
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: obviously these things come down to preference and this is just matching legion, but I would personally suggest writing this out as a function name rather than overloading an operator. This appears be doing an atomic reduce. The reduce_output helper function was a little bit difficult to parse with the <<= (a bit-shift operator borrowed for a different purpose) instead of just having a function call say exactly what the code is doing (result.non_exclusive_fold, e.g.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using RHS = typename REDOP::RHS;

public:
ScalarReductionBuffer(cudaStream_t stream) : buffer_(legate::create_buffer<LHS>(1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the class name obviously gets annoying long, but consider calling this 'DeviceScalarReductionBuffer' to make it clear this is not a general reduction buffer and is only designed for device reductions.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using RHS = typename REDOP::RHS;

public:
ScalarReductionBuffer(cudaStream_t stream) : buffer_(legate::create_buffer<LHS>(1))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since this is again only going to run on the device, do we want to explicitly pass GPU_FB_MEM to create_buffer to make it clearer what is happening? Otherwise this is using the default kind = NO_MEMKIND, which seems potentially fragile to rely on get_executing_processor() returning TOC_PROC to allocate this in the right place.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@magnatelee magnatelee merged commit e65032b into nv-legate:branch-22.10 Aug 6, 2022
@magnatelee magnatelee deleted the fix-perf-bug-scalar-reduction branch August 6, 2022 00:58
sbak5 pushed a commit to sbak5/cunumeric that referenced this pull request Aug 17, 2022
* Unify the template for device reduction tree and do some cleanup

* Fix performance bugs in scalar reduction kernels:

* Use unsigned 64-bit integers instead of signed integers wherever
  possible; CUDA hasn't added an atomic intrinsic for the latter yet.

* Move reduction buffers from zero-copy memory to framebuffer. This
  makes the slow atomic update code path in reduction operators
  run much more efficiently.

* Use thew new scalar reduction buffer in binary reductions as well

* Use only the RHS type in the reduction buffer as we never call apply

* Minor clean up per review

* Rename the buffer class and method to make the intent explicit

* Flip the polarity of reduce's template parameter
marcinz pushed a commit to marcinz/cunumeric that referenced this pull request Aug 17, 2022
* Unify the template for device reduction tree and do some cleanup

* Fix performance bugs in scalar reduction kernels:

* Use unsigned 64-bit integers instead of signed integers wherever
  possible; CUDA hasn't added an atomic intrinsic for the latter yet.

* Move reduction buffers from zero-copy memory to framebuffer. This
  makes the slow atomic update code path in reduction operators
  run much more efficiently.

* Use thew new scalar reduction buffer in binary reductions as well

* Use only the RHS type in the reduction buffer as we never call apply

* Minor clean up per review

* Rename the buffer class and method to make the intent explicit

* Flip the polarity of reduce's template parameter
marcinz added a commit that referenced this pull request Aug 17, 2022
* Unify the template for device reduction tree and do some cleanup

* Fix performance bugs in scalar reduction kernels:

* Use unsigned 64-bit integers instead of signed integers wherever
  possible; CUDA hasn't added an atomic intrinsic for the latter yet.

* Move reduction buffers from zero-copy memory to framebuffer. This
  makes the slow atomic update code path in reduction operators
  run much more efficiently.

* Use thew new scalar reduction buffer in binary reductions as well

* Use only the RHS type in the reduction buffer as we never call apply

* Minor clean up per review

* Rename the buffer class and method to make the intent explicit

* Flip the polarity of reduce's template parameter

Co-authored-by: Wonchan Lee <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants