Streamk v0.2 #646
base: main_perf
Conversation
Additionally, changes the locks to use uint8 instead of int32 for a smaller memory footprint.
I put in some comments.
Note: on gfx90a, loads/stores with cache_modifiers do not work. Documented here: https://github.com/ROCm/triton-internal/issues/311
rm1 = tl.max_contiguous(tl.multiple_of(rm1, BLOCK_SIZE_M), BLOCK_SIZE_M)
rn1 = tl.max_contiguous(tl.multiple_of(rn1, BLOCK_SIZE_N), BLOCK_SIZE_N)
P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]
tl.store(P_, acc, cache_modifier=".wt")
Note: on gfx90a, loads/stores with cache_modifiers do not work. Documented here: https://github.com/ROCm/triton-internal/issues/311
# todo: try use tl.load once cache modifier landed upstream
while tl.atomic_cas(locks + next_pid, 1, 1) != 1:
while (end < tile_iter_end and next_pid < NUM_SMS):
while tl.load(locks + next_pid, cache_modifier=".cv", volatile=True) != 1:
This also does not work on gfx90a: https://github.com/ROCm/triton-internal/issues/311
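For readers unfamiliar with the pattern in the hunk above: the kernel spins on a flag that another program id sets once its partial tile is written, using a write-through store (`.wt`) on the producer side and an uncached load (`.cv`, volatile) on the consumer side. A hedged CPU analogue of that producer/consumer handshake (pure Python with threads standing in for programs, a list element standing in for `locks[next_pid]`):

```python
# CPU sketch (assumption: illustrative analogue only, not the kernel itself).
# producer()  ~ tl.store(locks + pid, 1, cache_modifier=".wt")
# consumer()  ~ while tl.load(locks + next_pid, cache_modifier=".cv", volatile=True) != 1
import threading
import time

flag = [0]      # stands in for locks[next_pid]
result = []

def producer():
    time.sleep(0.01)
    flag[0] = 1                 # publish: "my partial tile is ready"

def consumer():
    while flag[0] != 1:         # spin until the producer's flag is visible
        time.sleep(0.001)
    result.append("done")       # safe to read the producer's partial results

t_cons = threading.Thread(target=consumer)
t_prod = threading.Thread(target=producer)
t_cons.start()
t_prod.start()
t_prod.join()
t_cons.join()
assert result == ["done"]
```

The cache modifiers matter on the GPU because, without them, the consumer can keep re-reading a stale cached copy of the lock; that is exactly the behavior the linked gfx90a issue reports.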
EVEN_K: tl.constexpr,
):
pid = tl.program_id(0)
pid = get_new_pid(pid, num_cus)
pid = (pid % 8) * (NUM_SMS // 8) + (pid // 8)
This is not needed for anything but gfx942, so we will actually remove this if the arch is gfx90a.
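To see what the `(pid % 8) * (NUM_SMS // 8) + (pid // 8)` line in the hunk above does: it permutes program ids so that consecutive pids are spread across the 8 XCDs of gfx942 instead of landing on the same die. A minimal sketch (assumption: `NUM_SMS` is a multiple of 8; the value 64 below is illustrative, the real value is the GPU's CU count):

```python
# Pure-Python sketch of the pid swizzle from the kernel prologue.
def remap_pid(pid: int, num_sms: int) -> int:
    # Consecutive pids go to different groups of num_sms // 8 programs,
    # so neighbouring tiles are spread across the 8 dies.
    return (pid % 8) * (num_sms // 8) + (pid // 8)

NUM_SMS = 64  # illustrative; must be a multiple of 8 for this remap
remapped = [remap_pid(p, NUM_SMS) for p in range(NUM_SMS)]

# The remap is a bijection on [0, NUM_SMS): every pid is used exactly once.
assert sorted(remapped) == list(range(NUM_SMS))
```

Because the remap is a permutation, no work is lost or duplicated; only the assignment of tiles to CUs changes.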
P = torch.zeros((num_cus, block_m * block_n), device="cuda", dtype=torch.float32)
triton_output = matmul(a, b, c, P, locks, num_cus, block_m, block_n, block_k, group_m, num_warps, num_stages,
                       waves_per_eu, mfmaInstrSize, kpack, EVEN_K)
locks = torch.zeros((num_sms, ), device="cuda", dtype=torch.int32)
`locks` can be smaller than int32; we only need 1 byte per lock, so uint8 should work.
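The footprint argument in the comment above is straightforward: each lock only ever holds 0 or 1, so a 1-byte element suffices. A pure-Python illustration using the standard `array` module (assumption: the PR itself allocates torch tensors; `num_sms = 304` below is just an example CU count):

```python
# int32 vs uint8 lock arrays: same length, 4x difference in bytes.
from array import array

num_sms = 304  # illustrative CU count
locks_i32 = array('i', [0] * num_sms)  # 'i' = signed int, 4 bytes per lock
locks_u8 = array('B', [0] * num_sms)   # 'B' = unsigned char, 1 byte per lock

assert locks_i32.itemsize == 4
assert locks_u8.itemsize == 1
# Total buffer sizes: 1216 bytes vs 304 bytes.
assert locks_i32.itemsize * num_sms == 4 * locks_u8.itemsize * num_sms
```

In the torch version this corresponds to allocating the locks tensor with `dtype=torch.uint8` instead of `dtype=torch.int32`.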
streamk v0.2:
- new streamk tuning script to reduce compile and profiling time
- use load/store cache modifiers to reimplement the spin lock
- add a CI test for the streamk kernel
- enable StreamPipelineV2