
Add half precision float16 data type #5716

Merged
merged 19 commits into PaddlePaddle:develop on Dec 6, 2017

Conversation

kexinzhao
Contributor

No description provided.

@kexinzhao kexinzhao changed the title Add half precision float16 data type [WIP] Add half precision float16 data type Nov 17, 2017
@kexinzhao kexinzhao changed the title [WIP] Add half precision float16 data type Add half precision float16 data type Nov 20, 2017

#include <cstdint>

#include <cuda.h>
Contributor

This needs an #ifdef PADDLE_WITH_CUDA guard.
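
A minimal sketch of the suggested guard, assuming PADDLE_WITH_CUDA is the macro the build system defines when CUDA support is enabled:

#include <cstdint>

#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#endif  // PADDLE_WITH_CUDA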

Contributor Author

Thanks! Will fix.


namespace fp16_impl {
// Convert from float to half precision in round-to-nearest-even mode
PADDLE_HOSTDEVICE inline float16 float_to_half_rn(float f);
Contributor

Maybe float_to_half_rn is better as a member of class float16.

Contributor Author

Done.

// float16_t is an alias for __fp16 in arm_fp16.h,
// which is included in arm_neon.h.
PADDLE_HOSTDEVICE inline float16(const float16_t& h) {
float16_t tmp = h;
Contributor

Can this assignment statement be removed?

Contributor Author

Yes, will fix.


PADDLE_HOSTDEVICE inline explicit float16(bool b) : x(b ? 0x3c00 : 0) {}

PADDLE_HOSTDEVICE inline explicit float16(int8_t val) {
Contributor

Lines 125-173 can be simplified with a template.
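
A hedged sketch of the template-based simplification (the exact conversion path is illustrative): a single templated constructor that routes the remaining arithmetic types through the existing float conversion replaces the per-type constructors.

// Fallback for arithmetic types not handled by a dedicated constructor:
// convert through float, then reuse the float16(float) conversion path.
template <typename T>
PADDLE_HOSTDEVICE inline explicit float16(const T& val)
    : x(float16(static_cast<float>(val)).x) {}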

Contributor Author

Will fix.

#endif

#ifdef PADDLE_ARM
#ifdef __F16C__
Contributor

Maybe lines 70-72 can be removed.
The ARM environment does not seem to define the __F16C__ macro.

Contributor Author

Thanks! Will fix.

return *reinterpret_cast<float16*>(&tmp);

#elif defined(PADDLE_NEON_64)
float16 res;
Contributor

Can we use vcvt_f16_f32 and vget_lane_f16?
I think this would avoid writing two separate pieces of code for NEON_64 and NEON_32.
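
A hedged sketch of the suggested single code path (the helper name is illustrative), assuming <arm_neon.h> with half-precision storage (__fp16) and the NEON half-precision conversion instructions are available on both AArch32 and AArch64 builds:

#include <arm_neon.h>

// Scalar float -> half conversion via vector intrinsics that exist for both
// the NEON_32 and NEON_64 builds.
inline float16_t float_to_half(float f) {
  float32x4_t v = vdupq_n_f32(f);   // broadcast the scalar into a vector
  float16x4_t h = vcvt_f16_f32(v);  // narrow all four lanes to half precision
  return vget_lane_f16(h, 0);       // extract lane 0 as a scalar __fp16
}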

Contributor Author

Great point! Will fix.


// On ARMv8.2-A CPU
#elif defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
(PADDLE_GNUC_VER >= 71 || PADDLE_CLANG_VER >= 39)
Contributor

use of undeclared identifier 'vaddh_f16'

I did not find arm_fp16.h in android-ndk-r15c, whose clang compiler version is 5.0.
Did I miss something?

Contributor Author

Currently clang does not support the ARMv8.2 float16 NEON intrinsics (the in-development clang 6.0 plans to add this support). So I added assembly code for the float16 arithmetic operators on the ARMv8.2 architecture, which should work for both gcc and clang.
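
For reference, a hedged sketch of what such an assembly-based operator can look like; it assumes an AArch64 build with -march=armv8.2-a+fp16 and a float16 struct whose data member is uint16_t x:

inline float16 operator+(const float16& a, const float16& b) {
  float16 res;
  asm volatile(
      "ld1 {v0.h}[0], [%[a_ptr]]\n"    // load the 16 bits of a into lane 0 of v0
      "ld1 {v1.h}[0], [%[b_ptr]]\n"    // load the 16 bits of b into lane 0 of v1
      "fadd h0, h0, h1\n"              // scalar half-precision add (ARMv8.2-A FP16)
      "st1 {v0.h}[0], [%[res_ptr]]\n"  // store the result back to memory
      :  // outputs
      :  // inputs
      [a_ptr] "r"(&(a.x)), [b_ptr] "r"(&(b.x)), [res_ptr] "r"(&(res.x))
      :  // clobbers
      "memory", "v0", "v1");
  return res;
}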


PADDLE_HOSTDEVICE inline float16(const Eigen::half& h) : x(h.x) {}

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
Contributor

Where is PADDLE_ARM_FP16 defined?

Contributor Author

PADDLE_ARM_FP16 is not defined anywhere yet. It is intended to be defined by the build system when it detects that the current CPU is ARMv8.2-A (please check this comment: #4853 (comment)).

In that comment, the ARM Compute Library uses SCons as its build tool and defines ARM_COMPUTE_ENABLE_FP16 when the ARMv8.2 architecture is found. I want cmake to do something similar for PADDLE_ARM_FP16, but I haven't found a way. @hedaoyuan, do you know how to do that?

Contributor
@hedaoyuan hedaoyuan Nov 23, 2017

We can define -DPADDLE_ARM_FP16 in cmake when the architecture is specified as ARMv8.2.
Like this: https://github.com/PaddlePaddle/Paddle/blob/develop/cmake/configure.cmake#L24

Contributor Author

Done. I am using the code below to specify -DPADDLE_ARM_FP16:

if(WITH_ARM_FP16)
  add_definitions(-DPADDLE_ARM_FP16)
  add_definitions("-march=armv8.2-a+fp16+simd")
endif(WITH_ARM_FP16)

}

__host__ inline bool operator<(const float16& a, const float16& b) {
#ifdef PADDLE_NEON_64
Contributor

Is PADDLE_NEON_64 still needed here?
This code is already under the PADDLE_ARM_FP16 macro.

Contributor Author

PADDLE_ARM_FP16 is CPU-architecture related (it is intended to be defined when the ARMv8.2-A architecture is found).
PADDLE_NEON_64 is more about the execution state of ARMv8.2-A, because I believe an ARMv8.2-A CPU can run either in the 32-bit state (when arm is defined) or the 64-bit state (when aarch64 is defined). GCC provides different sets of ARM intrinsics for arm and aarch64. That's why I define PADDLE_NEON_64 here.
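
For context, a hedged sketch of how these architecture macros can be derived from compiler predefines (illustrative; the actual definitions in the header may differ slightly):

#if defined(__arm__) || defined(__aarch64__)
#define PADDLE_ARM
#endif

#if defined(__ARM_NEON) || defined(__ARM_NEON__)
#define PADDLE_NEON
#endif

#if defined(PADDLE_NEON) && defined(__arm__)
#define PADDLE_NEON_32  // 32-bit execution state (arm)
#endif

#if defined(PADDLE_NEON) && defined(__aarch64__)
#define PADDLE_NEON_64  // 64-bit execution state (aarch64)
#endif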

Contributor

So, when I specify the architecture as ARMv8.2 (for those float16 instructions), can I still compile a 32-bit program?

Contributor Author

I don't think so.

An ARMv8 CPU only runs in the 32-bit (arm) state when the operating system is 32-bit, which is the case for the Raspberry Pi 3 Model B.

I don't think anyone would run a 32-bit OS on an ARMv8.2 CPU, so I will delete PADDLE_NEON_64.

Contributor

A 32-bit OS can run on an ARMv8.2 CPU. Also, a 32-bit program can run on a 64-bit OS (on an ARMv8.2 CPU).
My point is that when you compile a program that uses the float16 instructions, it may only be compilable as a 64-bit program.

Contributor Author

To reflect this, the current code assumes 64-bit compilation when PADDLE_ARM_FP16 is defined.
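
A hedged sketch of how that assumption could be made explicit at compile time (illustrative; not necessarily part of the merged change):

#if defined(PADDLE_ARM_FP16) && !defined(__aarch64__)
#error "PADDLE_ARM_FP16 assumes a 64-bit (aarch64) build"
#endif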


// Arithmetic operators
#if defined(PADDLE_CUDA_FP16) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530
__device__ inline float16 operator+(const float16& a, const float16& b) {
Contributor

Do we need to define these CUDA device operations?
We can use the half type directly in CUDA kernels.

Contributor Author
@kexinzhao kexinzhao Nov 22, 2017

Please refer to https://github.com/PaddlePaddle/Paddle/pull/5851/files for the different levels of support that different CUDA versions provide for the half type.

By defining these CUDA device operations here, along with the implicit conversion operator between our float16 and half, we can run the following code on CUDA < 9.0 (I have tested this on our nvidia-docker image with CUDA 8.0):

namespace paddle {
// the kernel name here is illustrative; the pasted snippet omitted it
__global__ void test_half_add() {
  half a, b, c;
  // correct for cuda >= 7.5 if defined inside paddle namespace
  // gives compiler error if not put in paddle namespace for cuda < 9.0
  c = a + b;
}
}  // namespace paddle

So these device operations make code that uses CUDA half arithmetic easy to write and compatible with all CUDA >= 7.5.

However, if we call c = a + b with a, b, c all being half data type, it is much less efficient compared to c = __hadd(a, b) because of all the unnecessary conversions performed.

So I think we should instead add the following code in the paddle namespace (the add operation, for example):

__device__ inline half operator+(const half& a, const half& b) { 
    return __hadd(a, b); 
}

This way c = a + b works on GPU for any CUDA >= 7.5 (for CUDA 9.0, this paddle::operator+ will be preferred over the counterpart in the global namespace because of name hiding in nested scope).

What do you think, @hedaoyuan?

Contributor

However, if we call c = a + b with a, b, c all being half data type, it is much less efficient compared to c = __hadd(a, b) because of all the unnecessary conversions performed.

Yeah, implicit conversion is dangerous; the conversion function should be declared explicit.

Contributor
@hedaoyuan hedaoyuan Nov 23, 2017

So, for now, there are two ways to implement an FP16 kernel in CUDA.

  1. Use half with operators like c = __hadd(a, b) when CUDA < 9.0; such a kernel also works when CUDA >= 9.0.
  2. Use half with operators like c = a + b when CUDA >= 9.0.

For the first, an outside contributor may not be familiar with Paddle's type definitions.
For the second, if we need kernels written for CUDA >= 9.0 to also work well when CUDA < 9.0, I think the approach in your comment is better (define operator+ for the CUDA half type when CUDA < 9.0).

Contributor Author

Done. Added half arithmetic operators for CUDA >= 7.5 and < 9.0.
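
A hedged sketch of the shape of that change (the exact guard in the merged code may differ): CUDA_VERSION comes from <cuda.h>, e.g. 9000 for CUDA 9.0, <cuda_fp16.h> is assumed to be included, and the __CUDA_ARCH__ check mirrors the one already used above.

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 530 && \
    CUDA_VERSION >= 7050 && CUDA_VERSION < 9000
// Let c = a + b on raw half values resolve to the intrinsic instead of
// going through float16 conversions.
__device__ inline half operator+(const half& a, const half& b) {
  return __hadd(a, b);
}
#endif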


#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
(PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
PADDLE_HOSTDEVICE inline float16& operator=(const float16_t& rhs) {
Collaborator

I noticed that the following pattern

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)

appears in this file three times. Should we define a new macro to improve readability?

#if defined(PADDLE_NEON) && defined(PADDLE_ARM_FP16) && \
    (PADDLE_GNUC_VER >= 61 || PADDLE_CLANG_VER >= 34)
#  define PADDLE_WITH_NATIVE_FP16 
#endif

Contributor Author

Thanks! Will do.

@reyoung reyoung requested review from qingqing01 and removed request for reyoung November 28, 2017 04:12
#endif // __clang__

#ifdef __CUDACC__
#define PADDLE_HOSTDEVICE __host__ __device__
Collaborator

Contributor Author

Thanks! Done.

Contributor
@hedaoyuan hedaoyuan left a comment

Great work!

@kexinzhao kexinzhao merged commit 1d1555e into PaddlePaddle:develop Dec 6, 2017
@kexinzhao kexinzhao deleted the float16 branch December 6, 2017 08:26
@gongweibao gongweibao added the AMP label Feb 10, 2020