
[Relay] Handle float16 constants & fix BatchNorm #3260

Merged
1 commit merged into apache:master on May 31, 2019

Conversation

@cbalint13 (Contributor) commented May 30, 2019

This PR fixes:

  • constant expressions targeted as float16 in Relay;
  • the batchnorm module is no longer hard-coded to float32;
  • the Relay testcase is extended to cover float16 as well.

It was tested on a real Yolo-Tiny net targeted as float16; it now works well, including auto-tuning.

It is based on @anijain2305's suggestion from this discuss.tvm.ai thread.
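For context, a minimal sketch of the kind of coverage the extended testcase aims at, using the Relay Python API (shapes, names, and the exact pass invocations here are illustrative assumptions, not the PR's actual test):

import numpy as np
import tvm
from tvm import relay

dtype = "float16"
x = relay.var("x", shape=(1, 4, 8, 8), dtype=dtype)
gamma = relay.var("gamma", shape=(4,), dtype=dtype)
beta = relay.var("beta", shape=(4,), dtype=dtype)
mean = relay.var("mean", shape=(4,), dtype=dtype)
var = relay.var("var", shape=(4,), dtype=dtype)

# batch_norm returns a tuple; [0] is the normalized output.
y = relay.nn.batch_norm(x, gamma, beta, mean, var)[0]
# A float16 constant expression of the kind this PR makes representable.
y = relay.add(y, relay.const(np.float16(1.0)))

mod = tvm.IRModule.from_expr(relay.Function([x, gamma, beta, mean, var], y))
mod = relay.transform.InferType()(mod)
# SimplifyInference unpacks batch_norm into elementwise ops.
mod = relay.transform.SimplifyInference()(mod)
print(mod)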

@tqchen (Member) commented May 30, 2019

Thanks for the contribution. Please request reviews from reviewers.

@cbalint13 (Contributor, Author)

cc @jroesch @MarisaKirisame @wweic @icemelon9

Please help with the review.

@@ -52,7 +53,6 @@ Expr BatchNormToInferUnpack(const Attrs attrs,
}

int axis = param->axis;
auto ttype = tdata.as<TensorTypeNode>();
CHECK(ttype);
Review comment (Member):

Let's also move this CHECK to the beginning of the function.

// convert to float16
// storage is uint16_t
*static_cast<DType*>(arr->data) =
__truncXfYf2__<float, uint32_t, 23, uint16_t, uint16_t, 10>(static_cast<float>(value));
@vinx13 (Member):

Please add an assertion that T is float32.

@cbalint13 (Contributor, Author) commented May 30, 2019

@vinx13

  • static_cast<float>(value) makes sure __truncXfYf2__() is fed a float32 (a rough illustration of what that conversion computes follows this list);
  • unsure: do we still want to assert that T is float32, or should we let T be any type?
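For illustration only, a rough numpy equivalent of what the C++ call above computes (assuming numpy's float32-to-float16 rounding matches the intent of __truncXfYf2__; this is not code from the PR):

import numpy as np

value = 0.1
f32 = np.float32(value)       # static_cast<float>(value)
f16 = f32.astype(np.float16)  # float32 -> float16 conversion
bits = f16.view(np.uint16)    # the raw uint16_t bit pattern that gets stored
print(f32, f16, hex(int(bits)))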

Review comment (Contributor):

This bit manipulation seems to bring a certain level of tediousness/ugliness into the TVM codebase.

A way to resolve this would be to keep the float constants always in FP32 and insert a Relay cast operation to FP16 if need be. The cast would then become part of fold_constant in the Relay graph, and LLVM would generate the FP32-to-FP16 conversion (hiding the bit manipulation). A sketch of this alternative follows below.

I do not have a strong opinion either way. What do you think @vinx13?
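A rough sketch of that alternative (illustrative only, using the standard Relay FoldConstant pass; not code from this PR): the literal stays in FP32, an explicit cast to FP16 is inserted, and constant folding pre-computes the FP16 value at compile time.

import tvm
from tvm import relay

c32 = relay.const(3.14159, dtype="float32")  # keep the literal in FP32
c16 = relay.cast(c32, "float16")             # explicit Relay cast to FP16

mod = tvm.IRModule.from_expr(relay.Function([], c16))
mod = relay.transform.FoldConstant()(mod)    # folds the cast at compile time
print(mod)  # the body becomes a single float16 constant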

@jroesch (Member):
Because we support arbitrary-bit types in the IR, we should really use gmp and mpfr for storing literals. It is bad practice to store them in finite types, as truncation may occur and it is impossible to test programmatically whether truncation occurred without weird hacks. This is out of the scope of this PR, unfortunately.

@vinx13 (Member) commented May 31, 2019

@cbalint13 ignore my comments :) the static cast is sufficient here.

I agree with @jroesch that we should do the pre-computation in higher precision; we can do this in the future.


@vinx13 merged commit 584a32a into apache:master on May 31, 2019
@vinx13 (Member) commented May 31, 2019

Thanks @cbalint13 @wweic @jroesch @anijain2305, this is merged.
