Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add bfloat16 support for CPU #5497

Merged
merged 6 commits into from
Apr 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,11 +203,11 @@ endif(NOT MSVC)
# Compile LIBXSMM
if((NOT MSVC) AND USE_LIBXSMM)
if(REBUILD_LIBXSMM)
add_custom_target(libxsmm COMMAND make realclean COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0
add_custom_target(libxsmm COMMAND make realclean COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0 CC=${CMAKE_C_COMPILER}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
)
else(REBUILD_LIBXSMM)
add_custom_target(libxsmm COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0
add_custom_target(libxsmm COMMAND make -j ECFLAGS="-Wno-error=deprecated-declarations" BLAS=0 CC=${CMAKE_C_COMPILER}
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/third_party/libxsmm
)
endif(REBUILD_LIBXSMM)
Expand Down
79 changes: 53 additions & 26 deletions include/dgl/aten/macro.h
Original file line number Diff line number Diff line change
Expand Up @@ -152,42 +152,69 @@
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef __nv_bfloat16 FloatType; \
{ __VA_ARGS__ } \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
} else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
LOG(FATAL) << (val_name) << " can't be float16 on CPU"; \
} else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef BFloat16 FloatType; \
{ __VA_ARGS__ } \
frozenbugs marked this conversation as resolved.
Show resolved Hide resolved
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/bfloat16/float32/float64 on GPU"; \
} \
} while (0)
#else // BF16_ENABLED
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if (XPU == kDGLCPU) { \
LOG(FATAL) << (val_name) << " can only be float32 or float64 on CPU"; \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/float32/float64 on GPU"; \
} \
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLFloat) { \
typedef __half FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCUDA && (val).bits == 16 && (val).code == kDGLBfloat) { \
LOG(FATAL) << "bfloat16 requires CUDA >= 11.0"; \
} else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLFloat) { \
LOG(FATAL) << (val_name) << " can't be float16 on CPU"; \
} else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef BFloat16 FloatType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be float16/float32/float64 on GPU"; \
} \
} while (0)
#endif // BF16_ENABLED
#else // DGL_USE_CUDA
#define ATEN_FLOAT_TYPE_SWITCH_16BITS(val, FloatType, XPU, val_name, ...) \
ATEN_FLOAT_TYPE_SWITCH(val, FloatType, val_name, {__VA_ARGS__})
do { \
CHECK((val).code == kDGLFloat || (val.code == kDGLBfloat)) \
<< (val_name) << " must be float type"; \
if ((val).bits == 32) { \
typedef float FloatType; \
{ __VA_ARGS__ } \
} else if ((val).bits == 64) { \
typedef double FloatType; \
{ __VA_ARGS__ } \
} else if ( \
XPU == kDGLCPU && (val).bits == 16 && (val).code == kDGLBfloat) { \
typedef BFloat16 FloatType; \
{ __VA_ARGS__ } \
} else { \
LOG(FATAL) << (val_name) \
<< " can only be bfloat16/float32/float64 on CPU"; \
} \
} while (0)
#endif // DGL_USE_CUDA

/**
Expand Down
68 changes: 68 additions & 0 deletions include/dgl/runtime/bfloat16.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
/**
* Copyright (c) 2023 by Contributors
* @file dgl/runtime/ndarray.h
* @brief BFloat16 CPU header
*/
#ifndef DGL_RUNTIME_BFLOAT16_H_
#define DGL_RUNTIME_BFLOAT16_H_

#include <cmath>

class BFloat16 {
uint16_t val;

public:
constexpr BFloat16() : val(0) {}
// Disable lint "explicit" warning, since implicit usage on constructor is
// expected.
BFloat16(float f) { // NOLINT
if (std::isnan(f)) {
val = 0x7FC0;
} else {
union {
uint16_t iraw16[2];
uint32_t iraw32;
float f32;
};

f32 = f;
const uint32_t rounding_bias = 0x00007FFF + (iraw16[1] & 0x1);
val = static_cast<uint16_t>((iraw32 + rounding_bias) >> 16);
}
}
static constexpr BFloat16 Min() {
BFloat16 min;
min.val = 0xFF80;
return min;
}

static constexpr BFloat16 Max() {
BFloat16 max;
max.val = 0x7F80;
return max;
}

BFloat16& operator-=(const float& rhs) {
float lhs = (*this);
(*this) = lhs - rhs;
return *this;
}

BFloat16& operator+=(const float& rhs) {
float lhs = (*this);
(*this) = lhs + rhs;
return *this;
}

operator float() const {
union {
float f;
uint16_t raw[2];
};
raw[0] = 0;
raw[1] = val;
Comment on lines +62 to +63
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it assuming big endian byte order? If it's little endian I think this should be

raw[0] = val;

Also, could we have a test case on converting between bfloat16 and float32?

return f;
}
};

#endif // DGL_RUNTIME_BFLOAT16_H_
1 change: 1 addition & 0 deletions include/dgl/runtime/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <utility>
#include <vector>

#include "bfloat16.h"
#include "c_runtime_api.h"
#include "serializer.h"
#include "shared_mem.h"
Expand Down
2 changes: 1 addition & 1 deletion python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def config_cython():
library_dirs=library_dirs,
libraries=libraries,
# Crashes without this flag with GCC 5.3.1
extra_compile_args=["-std=c++11"],
extra_compile_args=["-std=c++14"],
language="c++",
)
)
Expand Down
22 changes: 22 additions & 0 deletions src/array/cpu/gather_mm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ void GatherMMScatter(
LOG(FATAL) << "Unsupported CPU kernel for GatherMM.";
}

template void GatherMM<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
template void GatherMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);
Expand All @@ -53,6 +59,12 @@ template void GatherMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b);

template void GatherMMScatter<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
template void GatherMMScatter<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);
Expand All @@ -66,6 +78,12 @@ template void GatherMMScatter<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray idx_a,
const NDArray idx_b, const NDArray idx_c);

template void SegmentMM<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
template void SegmentMM<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);
Expand All @@ -79,6 +97,10 @@ template void SegmentMM<kDGLCPU, int64_t, double>(
const NDArray A, const NDArray B, NDArray C, const NDArray seglen_A,
bool a_trans, bool b_trans);

template void SegmentMMBackwardB<kDGLCPU, int32_t, BFloat16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, BFloat16>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int32_t, float>(
const NDArray A, const NDArray dC, NDArray dB, const NDArray seglen);
template void SegmentMMBackwardB<kDGLCPU, int64_t, float>(
Expand Down
36 changes: 36 additions & 0 deletions src/array/cpu/sddmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,12 @@ void SDDMMCsrHetero(
});
}

template void SDDMMCsr<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCsr<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
Expand All @@ -91,6 +97,18 @@ template void SDDMMCsr<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const CSRMatrix& csr,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

template void SDDMMCsrHetero<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCsrHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<CSRMatrix>& vec_csr, const std::vector<NDArray>& lhs,
Expand Down Expand Up @@ -152,6 +170,12 @@ void SDDMMCooHetero(
});
}

template void SDDMMCoo<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
template void SDDMMCoo<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);
Expand All @@ -165,6 +189,18 @@ template void SDDMMCoo<kDGLCPU, int64_t, double>(
const std::string& op, const BcastOff& bcast, const COOMatrix& coo,
NDArray lhs, NDArray rhs, NDArray out, int lhs_target, int rhs_target);

template void SDDMMCooHetero<kDGLCPU, int32_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int64_t, BFloat16>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
const std::vector<NDArray>& rhs, std::vector<NDArray> out, int lhs_target,
int rhs_target, const std::vector<dgl_type_t>& in_eid,
const std::vector<dgl_type_t>& out_eid);
template void SDDMMCooHetero<kDGLCPU, int32_t, float>(
const std::string& op, const BcastOff& bcast,
const std::vector<COOMatrix>& vec_coo, const std::vector<NDArray>& lhs,
Expand Down
34 changes: 34 additions & 0 deletions src/array/cpu/segment_reduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ void BackwardSegmentCmp(NDArray feat, NDArray arg, NDArray out) {
cpu::BackwardSegmentCmp<IdType, DType>(feat, arg, out);
}

template void SegmentReduce<kDGLCPU, int32_t, BFloat16>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int64_t, BFloat16>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
template void SegmentReduce<kDGLCPU, int32_t, float>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);
Expand All @@ -69,6 +75,16 @@ template void SegmentReduce<kDGLCPU, int64_t, double>(
const std::string& op, NDArray feat, NDArray offsets, NDArray out,
NDArray arg);

template <>
void ScatterAdd<kDGLCPU, int32_t, BFloat16>(
NDArray feat, NDArray idx, NDArray out) {
LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
}
template <>
void ScatterAdd<kDGLCPU, int64_t, BFloat16>(
NDArray feat, NDArray idx, NDArray out) {
LOG(FATAL) << "Unsupported CPU kernel for ScatterAdd for BF16.";
}
template void ScatterAdd<kDGLCPU, int32_t, float>(
NDArray feat, NDArray idx, NDArray out);
template void ScatterAdd<kDGLCPU, int64_t, float>(
Expand All @@ -78,6 +94,20 @@ template void ScatterAdd<kDGLCPU, int32_t, double>(
template void ScatterAdd<kDGLCPU, int64_t, double>(
NDArray feat, NDArray arg, NDArray out);

template <>
void UpdateGradMinMax_hetero<kDGLCPU, int32_t, BFloat16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
}
template <>
void UpdateGradMinMax_hetero<kDGLCPU, int64_t, BFloat16>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out) {
LOG(FATAL) << "Unsupported CPU kernel for UpdateGradMinMax_hetero for BF16.";
}
template void UpdateGradMinMax_hetero<kDGLCPU, int32_t, float>(
const HeteroGraphPtr& g, const std::string& op,
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
Expand All @@ -95,6 +125,10 @@ template void UpdateGradMinMax_hetero<kDGLCPU, int64_t, double>(
const std::vector<NDArray>& feat, const std::vector<NDArray>& idx,
const std::vector<NDArray>& idx_etype, std::vector<NDArray>* out);

template void BackwardSegmentCmp<kDGLCPU, int32_t, BFloat16>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, BFloat16>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int32_t, float>(
NDArray feat, NDArray arg, NDArray out);
template void BackwardSegmentCmp<kDGLCPU, int64_t, float>(
Expand Down
2 changes: 2 additions & 0 deletions src/array/cpu/segment_reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ namespace cpu {
*/
template <typename IdType, typename DType>
void SegmentSum(NDArray feat, NDArray offsets, NDArray out) {
if (std::is_same<DType, BFloat16>::value)
LOG(FATAL) << "Unsupported CPU kernel for SegmentSum for BF16.";
int n = out->shape[0];
int dim = 1;
for (int i = 1; i < out->ndim; ++i) dim *= out->shape[i];
Expand Down
Loading