Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(bgmv): write shared_memory y_warpsize only when threadIdx.x == 0 #51

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

menggeliu1205
Copy link

should add threadIdx.x == 0, when you want to write y_warpsize. Otherwise it will lead the wrong answer.

Copy link
Contributor

@yzh119 yzh119 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@menggeliu1205 thanks for the bugfix and it looks good to me.

An alternative fix is to add a sum = __shfl_sync(0xffffffff, sum, 0); before y_warpwise[threadIdx.y] ... to broadcast the reduction result from lane 0 to all lanes in the warp, either fix should work.

@menggeliu1205
Copy link
Author

thanks for replying and offer the other fix! i get it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants