[Relay] Handle float16 constants & fix BatchNorm #3260
Conversation
Thanks for the contribution, please request reviews from Reviewers.
cc @jroesch @MarisaKirisame @wweic @icemelon9 Please help with review.
src/relay/pass/simplify_inference.cc
Outdated
@@ -52,7 +53,6 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
}

int axis = param->axis;
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
let's also move this CHECK to the beginning of the function
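A minimal sketch of what the hoisted check might look like, assuming the signature shown in the hunk above and the usual batch_norm attrs type (parameter list abridged, not the exact code in this PR):

```cpp
// Sketch only: hoist the TensorTypeNode check to the top of
// BatchNormToInferUnpack so it fails before any other work is done.
// The parameter list is abridged; names follow the diff above.
Expr BatchNormToInferUnpack(const Attrs attrs, Expr data, /* ... */ Type tdata) {
  auto ttype = tdata.as<TensorTypeNode>();
  CHECK(ttype);  // validate the input type first

  const auto param = attrs.as<BatchNormAttrs>();  // assumed attrs type
  int axis = param->axis;
  // ... rest of the rewrite unchanged ...
}
```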
// convert to float16
// storage is uint16_t
*static_cast<DType*>(arr->data) =
    __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
please add an assertion that T is float32
static_cast<float>(T value) will make sure __truncXfYf2__() is fed with float32. Unsure: do we still want to assert(T is float32), or should we let T be any type?
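For reference, one hedged way to pin T down at compile time (a sketch, not the code in this PR; the helper name FillFloat16Slot is made up, and __truncXfYf2__ is assumed to be the same helper used in the snippet above):

```cpp
#include <cstdint>
#include <type_traits>

// Sketch only: restrict the constant-filling helper so T must be arithmetic,
// and route everything through float before truncation to float16 storage.
template <typename T>
void FillFloat16Slot(void* dst, T value) {
  static_assert(std::is_arithmetic<T>::value,
                "float16 constants must be built from an arithmetic value");
  // static_cast<float>(value) guarantees __truncXfYf2__ always sees float32,
  // whatever T is; a stricter std::is_same<T, float> assert is also possible.
  *static_cast<uint16_t*>(dst) =
      __truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(
          static_cast<float>(value));
}
```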
This bit manipulation seems to bring a certain level of tediousness/ugliness into the TVM codebase.
One way to resolve this would be to always keep float constants in FP32 and insert a Relay cast operation to FP16 if need be. The cast would then become part of fold_constant in the Relay graph, and LLVM would generate the FP32-to-FP16 conversion (hiding the bit manipulation).
I do not have a strong opinion either way. What do you think @vinx13?
Because we support arbitrary bit types in the IR, we should really use gmp and mpfr for storing literals. It is bad practice to store them in finite types, as truncation may occur and it is impossible to test whether truncation occurred programmatically w/o weird hacks. This is out of the scope of this PR unfortunately.
@cbalint13 ignore my comments :) the static cast is sufficient here
I agree with @jroesch that we should do the pre-computation in higher precision; we can do this in the future
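As a rough illustration of the GMP/MPFR idea mentioned above (purely a sketch of the concept, not part of this PR): a literal could be held at high precision and rounded to the target dtype only when the constant is materialized, so any truncation becomes observable.

```cpp
#include <mpfr.h>

// Sketch only: keep a literal in arbitrary precision (MPFR) and round to the
// target dtype as late as possible, so truncation is controlled and checkable.
int main() {
  mpfr_t literal;
  mpfr_init2(literal, 113);             // significand bits; 113 ~ IEEE binary128
  mpfr_set_d(literal, 0.1, MPFR_RNDN);  // store the source value

  // Round only when the constant is materialized for a concrete dtype.
  float as_f32 = mpfr_get_flt(literal, MPFR_RNDN);

  // mpfr_cmp-style checks could then detect whether rounding lost information.
  mpfr_clear(literal);
  return as_f32 > 0.0f ? 0 : 1;
}
```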
Thanks @cbalint13 @wweic @jroesch @anijain2305, this is merged.
This PR fixes float16 constant handling and the BatchNorm simplification discussed above.
It was tested on a real Yolo-Tiny net targeted at float16; it works well now, including auto-tuning.
It is based on @anijain2305's suggestion from this discuss.tvm.ai thread.