[oneCCL] enable bf16 comm. #73
@@ -78,7 +78,9 @@ class bfloat16_t {
}

static void cvt_float_to_bfloat16(const float *src, bfloat16_t *dst, int size);
static void batch_cvt_float_to_bfloat16(const float *src, bfloat16_t *dst, int size);
@changqi1 Are you OK with adding it to bfloat16.h?
@pujiang2018 @marvin-Yu Sure. It is great to add batch cvt, but the hard-coded number 1024 does not make sense.
bf16_enable = (getenv("XFT_ONECCL_BF16") ? atoi(getenv("XFT_ONECCL_BF16")) : 0);
if (bf16_enable) {
printf("got 'XFT_ONECCL_BF16=%d', enable BF16 dtype comm.\n", bf16_enable);
buf_bf16 = new bfloat16_t[MAX_BF16_BUFFER];
Suggest using aligned_alloc or malloc, as new will initialize all the data to 0, which is not needed here.
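A minimal sketch of the aligned_alloc suggestion, under stated assumptions: `Bf16Raw` is a hypothetical stand-in for `bfloat16_t`, `allocBf16Buffer` is an illustrative name, and the 64-byte default alignment is a guess at a cache-line-friendly value, not something taken from this PR. It relies on `std::aligned_alloc` from C++17, which is not available on every toolchain.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

// Hypothetical stand-in for bfloat16_t: a trivially constructible 16-bit type.
struct Bf16Raw {
    uint16_t bits;
};

// Allocate `count` uninitialized elements aligned to `alignment` bytes.
// std::aligned_alloc expects the size to be a multiple of the alignment,
// so the byte count is rounded up. Release the result with std::free.
inline Bf16Raw *allocBf16Buffer(size_t count, size_t alignment = 64) {
    // The alignment must be a non-zero power of two.
    assert(alignment != 0 && (alignment & (alignment - 1)) == 0);
    size_t bytes = count * sizeof(Bf16Raw);
    bytes = (bytes + alignment - 1) / alignment * alignment;
    if (bytes == 0) bytes = alignment; // avoid a zero-sized request
    return static_cast<Bf16Raw *>(std::aligned_alloc(alignment, bytes));
}
```

The buffer contents are indeterminate until the conversion writes them, and the matching release call would be `std::free` rather than `delete[]`.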
Please document the XFT_ONECCL_BF16 environment variable in the wiki.
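For the wiki entry, it may help to pin down how the variable is parsed. The sketch below is one possible way to read it with a single `getenv` call; `readEnvInt` is a hypothetical helper name and is not part of this PR. Unlike `atoi`, it falls back to a default when the value does not start with a number.

```cpp
#include <cstdlib>

// Hypothetical helper: read an integer environment variable once,
// returning `fallback` when it is unset or does not start with a number.
inline int readEnvInt(const char *name, int fallback) {
    const char *value = std::getenv(name);
    if (value == nullptr) return fallback;
    char *end = nullptr;
    long parsed = std::strtol(value, &end, 10);
    // end == value means strtol consumed no characters at all.
    return (end == value) ? fallback : static_cast<int>(parsed);
}
```

With this helper the check would read `bf16_enable = readEnvInt("XFT_ONECCL_BF16", 0);`, so any non-zero value enables BF16 communication and an unset variable leaves it disabled.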
TimeLine t("Messenger.reduceAdd");
if (bf16_enable) {
bfloat16_t::batch_cvt_float_to_bfloat16(sendBuf, buf_bf16, count);
What if count > MAX_BF16_BUFFER? The conversion would then write past the end of buf_bf16.
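One possible way to handle a count larger than the staging buffer is to process the data in pieces no larger than the buffer. The sketch below shows only the chunking logic; `forEachChunk` is a hypothetical helper, and the convert/allreduce steps named in the comments are assumptions about how it could be wired into `reduceAdd`, not code from this PR.

```cpp
#include <algorithm>
#include <cstddef>
#include <functional>

// Hypothetical helper: visit the range [0, count) in consecutive pieces of
// at most `capacity` elements, calling fn(offset, len) for each piece.
inline void forEachChunk(size_t count, size_t capacity,
                         const std::function<void(size_t, size_t)> &fn) {
    if (capacity == 0) return; // nothing can be staged in an empty buffer
    for (size_t offset = 0; offset < count; offset += capacity) {
        fn(offset, std::min(capacity, count - offset));
    }
}

// Assumed usage inside reduceAdd (not compiled here):
//   forEachChunk(count, MAX_BF16_BUFFER, [&](size_t off, size_t len) {
//       // 1. convert sendBuf + off (len floats) into buf_bf16
//       // 2. allreduce len bf16 elements in buf_bf16
//       // 3. convert buf_bf16 back into recvBuf + off
//   });
```

A simpler alternative is to fall back to the float path whenever count exceeds the buffer capacity; chunking keeps the bandwidth saving at the cost of more collective calls.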
if (bf16_enable) {
bfloat16_t::batch_cvt_float_to_bfloat16(sendBuf, buf_bf16, count);
ccl::allreduce(
buf_bf16, buf_bf16, count, ccl::datatype::bfloat16, ccl::reduction::sum, *pcomm).wait();
The code here does not look formatted.
@@ -175,7 +201,8 @@ class Messenger {
int size;
int rank;
bool local_ranks_flag;

bfloat16_t* buf_bf16 = nullptr;
Let's initialize it in the constructor to keep the style consistent.
@@ -128,6 +130,18 @@ inline void bfloat16_t::cvt_float_to_bfloat16(const float *src, bfloat16_t *dst,
}
}

inline void bfloat16_t::batch_cvt_float_to_bfloat16(const float *src, bfloat16_t *dst, int count){
constexpr int sizePerSplit = 1024;
Maybe use a template parameter to define blockSize, or use totalSize/totalSystemThreads to define the blockSize.
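A rough sketch of the second option, deriving the block size from the element count and the thread count instead of hard-coding 1024. `blockSizeFor` and the `granule` parameter are hypothetical; the default of 16 assumes a vector path that handles 16 floats per step, which may not match the actual conversion kernel.

```cpp
#include <algorithm>
#include <cstddef>

// Hypothetical helper: elements per block so that `count` elements are spread
// over at most `threads` blocks, with each block a multiple of `granule`.
inline size_t blockSizeFor(size_t count, size_t threads, size_t granule = 16) {
    if (threads == 0) threads = 1;
    if (granule == 0) granule = 1;
    size_t perThread = (count + threads - 1) / threads;             // ceiling division
    size_t rounded = (perThread + granule - 1) / granule * granule; // round up to granule
    return std::max(rounded, granule);                              // never return 0
}
```

The batch conversion could then loop over blocks with `for (size_t off = 0; off < count; off += blockSize)`, clamping the final block to `count - off`, with one block per thread.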
bf16_enable = (getenv("XFT_ONECCL_BF16") ? atoi(getenv("XFT_ONECCL_BF16")) : 0);
if (bf16_enable) {
printf("got 'XFT_ONECCL_BF16=%d', enable BF16 dtype comm.\n", bf16_enable);
buf_bf16 = new bfloat16_t[MAX_BF16_BUFFER];
Initializing buf_bf16 with the maximum length may cause a communication performance issue; please double-check this.
Enable communication in the bf16 data type via XFT_ONECCL_BF16.