Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[codegen] Add max(half, half) support when enable fp16 #3811

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/codegen/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ void CodeGenCUDA::AddFunction(LoweredFunc f) {
std::string CodeGenCUDA::Finish() {
if (enable_fp16_) {
decl_stream << "#include <cuda_fp16.h>\n";
decl_stream << "__device__ half max(const half a, const half b)\n"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we know which operators we have to overload as such? "max" is one of them. Do we need others?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For now, I only find max that need to be overloaded.

BTW, I have a question about the checks.
Why this commit cannot be built today? It was successful yesterday.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, we saw more failures while trying to run full resnet

#3816 (comment)

I think, we are missing all reduce ops. Will it be possible for you to help with this? (In a separate PR, this one is good to go)

"{\n return __hgt(__half(a), __half(b)) ? a : b;\n}\n";
}

if (enable_int8_) {
Expand Down
33 changes: 33 additions & 0 deletions tests/python/unittest/test_codegen_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,38 @@ def test_rfactor_predicates():

fcuda = tvm.build(s, [A, B], "cuda")

def test_cuda_vector_max():
num_thread = 8
target = 'cuda'
def check_vector_max(ctx, n, dtype):
if not tvm.gpu(0).exist or not tvm.module.enabled("cuda"):
print("skip because cuda is not enabled..")
return
if dtype == "float16" and not have_fp16(tvm.gpu(0).compute_version):
print("skip because gpu does not support fp16")
return
A = tvm.placeholder((n,), name='A', dtype=dtype)
B = tvm.placeholder((n,), name='B', dtype=dtype)
C = tvm.compute((n,), lambda i: tvm.max(A[i], B[i]), name='C')
s = tvm.create_schedule(C.op)
bx, tx = s[C].split(C.op.axis[0], factor=num_thread)
s[C].bind(bx, tvm.thread_axis("blockIdx.x"))
s[C].bind(tx, tvm.thread_axis("threadIdx.x"))
fun = tvm.build(s, [A,B,C], "cuda", name="vector_max")

np_a = np.random.uniform(size=n).astype(dtype)
np_b = np.random.uniform(size=n).astype(dtype)
np_c = np.maximum(np_a, np_b)
a = tvm.nd.empty((n,), A.dtype, ctx).copyfrom(np_a)
b = tvm.nd.empty((n,), B.dtype, ctx).copyfrom(np_b)
c = tvm.nd.empty((n,), C.dtype, ctx)
fun(a, b, c)
np.testing.assert_equal(c.asnumpy(), np_c)

ctx = tvm.context(target, 0)
check_vector_max(ctx, 10, "float32")
check_vector_max(ctx, 10, "float16")


if __name__ == "__main__":
test_cuda_vectorize_add()
Expand All @@ -266,3 +298,4 @@ def test_rfactor_predicates():
test_cuda_shuffle()
test_cuda_reducition_binding()
test_rfactor_predicates()
test_cuda_vector_max()