[CODEGEN][CUDA] Fix vector load #5226
Changes from all commits
@@ -291,7 +291,7 @@ static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);
  unsigned v1 = *((unsigned short *)&y);
-  return (v0 << 16) | v1;
+  return (v1 << 16) | v0;
}
)";

Review comment (on the changed return statement): good catch!
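For readers wondering about the bit order: below is a minimal host-side sketch, not part of the PR, assuming a little-endian target and compilation with nvcc so the host-side half helpers from cuda_fp16.h are available. It checks that the fixed packing reproduces the bit pattern of CUDA's own __halves2half2(x, y), whose first argument ends up in the low 16 bits.

// Minimal sketch (assumptions: little-endian layout, nvcc with cuda_fp16.h).
// Checks that the fixed __pack_half2 ordering matches a half2 {x, y} in memory.
#include <cuda_fp16.h>
#include <cstdio>
#include <cstring>

static unsigned pack_half2_fixed(const half x, const half y) {
  unsigned v0 = *((unsigned short *)&x);  // raw bits of x
  unsigned v1 = *((unsigned short *)&y);  // raw bits of y
  return (v1 << 16) | v0;                 // y -> high 16 bits, x -> low 16 bits
}

int main() {
  half x = __float2half(1.0f);
  half y = __float2half(2.0f);
  half2 h2 = __halves2half2(x, y);        // element 0 = x, element 1 = y
  unsigned ref;
  std::memcpy(&ref, &h2, sizeof(ref));    // bit pattern of the half2 in memory
  std::printf("%08x %08x\n", ref, pack_half2_fixed(x, y));
  return 0;
}

Both words should print the same value; with the old (v0 << 16) | v1 ordering the two 16-bit halves come out swapped, which is presumably why the vectorized fp16 path produced wrong results.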
@@ -543,6 +543,44 @@ def run_test(dtype):
    run_test("uint32")
    run_test("uint64")


def test_cuda_vectorize_load_permute_pad():
    def check_cuda(dtype, n, l, padding, lanes):
        if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
            print("skip because cuda is not enabled..")
            return
        if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
            print("Skip because gpu does not have fp16 support")
            return
Review comment: check if float16 is supported
Author reply: Already checked.

        ctx = tvm.gpu(0)
        A = tvm.te.placeholder((n, l), name='A', dtype=dtype)
        B = tvm.te.compute((n // lanes, l + 2 * padding, lanes),
                           lambda i, j, k: tvm.te.if_then_else(
                               tvm.te.any(j < padding, j >= l + padding),
                               tvm.runtime.convert(0).astype(dtype),
                               A[i * lanes + k, j - padding]),
                           name='B')
        s = te.create_schedule(B.op)
        block, thread, vectorize = s[B].op.axis
        s[B].bind(block, bx)
        s[B].bind(thread, tx)
        s[B].vectorize(vectorize)
        fun = tvm.build(s, [A, B], "cuda", name="vector_load_permute_pad")
        np_a = np.random.randint(
            low=-128, high=127, size=(n, l)).astype(A.dtype)
        a = tvm.nd.empty((n, l), A.dtype, ctx).copyfrom(np_a)
        b = tvm.nd.empty((n // lanes, l + padding * 2, lanes), B.dtype, ctx)
        fun(a, b)
        np_a_reshape = np_a.reshape(n // lanes, lanes, l).transpose(0, 2, 1)
        ref = np.pad(np_a_reshape, ((0, 0), (padding, padding), (0, 0)),
                     mode='constant', constant_values=0)
        tvm.testing.assert_allclose(b.asnumpy(), ref)
check_cuda("int8", 64, 16, 3, 4) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. uint8 test? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Already added uint8 test. |
||
check_cuda("uint8", 64, 16, 3, 4) | ||
check_cuda("int32", 64, 16, 3, 4) | ||
check_cuda("float16", 64, 16, 3, 4) | ||
check_cuda("float32", 64, 16, 3, 4) | ||
|
||
if __name__ == "__main__":
    test_cuda_vectorize_add()
    test_cuda_multiply_add()
@@ -560,3 +598,4 @@ def run_test(dtype):
    test_vectorized_intrin1()
    test_vectorized_intrin2()
    test_vectorized_popcount()
    test_cuda_vectorize_load_permute_pad()
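To make the new test more concrete, here is a hand-written CUDA analogue of what the schedule computes for dtype=int8 with lanes=4. This is not TVM's actual generated code; the kernel name and launch shape are illustrative. One block handles a group of 4 input rows, one thread handles one output column, and each thread issues a single vectorized char4 store.

// Illustrative analogue of the test's compute for int8/lanes=4 (not TVM output).
// Launched as, e.g.: vector_load_permute_pad<<<n / 4, l + 2 * padding>>>(dA, dB, l, padding);
__global__ void vector_load_permute_pad(const signed char *A, char4 *B,
                                        int l, int padding) {
  int i = blockIdx.x;   // which group of 4 input rows
  int j = threadIdx.x;  // output column in [0, l + 2*padding)
  char4 v;
  if (j < padding || j >= l + padding) {
    v = make_char4(0, 0, 0, 0);                 // zero fill in the padded region
  } else {
    int col = j - padding;
    v = make_char4(A[(i * 4 + 0) * l + col],    // permute: gather 4 rows of A
                   A[(i * 4 + 1) * l + col],    // into the lanes of one char4
                   A[(i * 4 + 2) * l + col],
                   A[(i * 4 + 3) * l + col]);
  }
  B[i * (l + 2 * padding) + j] = v;             // one 4-byte vectorized store
}

For the float16 variant of the same pattern, the lanes are packed with helpers like __pack_half2, which is where the bit-order fix above comes in.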
Review comment: Why do we care about the signedness? This just downcasts to 32 bits.

Author reply: TVM uses "uint" to store uint8x4 (in function PrintType). The change will generate code like "uint x = (uint)y" instead of "uint x = (int)y". What is your further opinion?

Review comment: Can we keep it as is? I do not see benefits from this change. Otherwise the entire PR LGTM. Thanks!

Review comment: I think it's not necessary to revert this change, if it's harmless. Considering that CodeGenCUDA::PrintType generates "uint" for uint8x4, this change somehow makes sense.
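For context on this thread, a small hedged illustration of the two cast styles being discussed; the variable names are invented and this is not TVM codegen output. It shows that for a value whose low 32 bits are being kept, the signed and unsigned casts yield the same word, so the disagreement is about matching the "uint" storage type rather than about the resulting bits.

// Illustrative only: the cast styles from the discussion above, with invented names.
#include <cstdio>

int main() {
  long long wide = 0x0102030405060708LL;       // value being downcast to 32 bits
  unsigned int as_uint = (unsigned int)wide;   // cast matching the "uint" storage type
  unsigned int as_int  = (int)wide;            // signed downcast, then converted to unsigned
  std::printf("%08x %08x\n", as_uint, as_int); // both print 05060708
  return 0;
}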