[CODEGEN][CUDA] Fix vector load #5226
Conversation
also cc @wpan11nv @ZihengJiang
@@ -291,7 +291,7 @@ static inline __device__ __host__ unsigned
 __pack_half2(const half x, const half y) {
   unsigned v0 = *((unsigned short *)&x);
   unsigned v1 = *((unsigned short *)&y);
-  return (v0 << 16) | v1;
+  return (v1 << 16) | v0;
good catch!
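For anyone skimming, a standalone sketch of why the order matters (pack_half2 below mirrors the patched helper; the kernel and values are only for illustration, not from the PR). A native half2 stores its first element in the low 16 bits of the 32-bit word, so x must go in the low half:

#include <cuda_fp16.h>
#include <cstdio>

// Mirrors the fixed __pack_half2: x -> low 16 bits, y -> high 16 bits.
static __device__ unsigned pack_half2(half x, half y) {
  unsigned v0 = *reinterpret_cast<unsigned short*>(&x);
  unsigned v1 = *reinterpret_cast<unsigned short*>(&y);
  return (v1 << 16) | v0;
}

__global__ void check() {
  half2 h2;
  *reinterpret_cast<unsigned*>(&h2) = pack_half2(__float2half(1.0f),
                                                 __float2half(2.0f));
  // Prints "1.000000 2.000000": h2.x is the low half of the packed word.
  // With the old (v0 << 16) | v1 order the two elements come out swapped.
  printf("%f %f\n", __half2float(h2.x), __half2float(h2.y));
}

int main() {
  check<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}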
src/target/source/codegen_c.cc (Outdated)
void CodeGenC::PrintVecElemLoadExpr(
    DataType t, int i, const std::string& value, std::ostream& os) {
  CHECK_GT(t.lanes(), 1);
  if (t.is_int() && t.bits() == 8) {
uint8?
Already supported uint8.
        (0, 0)), mode='constant', constant_values=0)
    tvm.testing.assert_allclose(b.asnumpy(), ref)

check_cuda("int8", 64, 16, 3, 4)
uint8 test?
Already added uint8 test.
src/target/source/codegen_cuda.cc (Outdated)
void CodeGenCUDA::PrintVecElemLoadExpr(
    DataType t, int i, const std::string& value, std::ostream& os) {
  CHECK_GT(t.lanes(), 1);
  if (t.is_int() && t.bits() == 8) {
uint8_t
Already supported uint8.
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
    print("skip because cuda is not enabled..")
    return
check if float16 is supported
Already checked.
Thanks for your review, @wpan11nv. The new commit supports uint8, and it also fixes the uint8 bug in the code generation of BroadcastNode.
// make_int8x4
const int64_t *p = as_const_int(op->value);
CHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
- os << "(int)" << v;
+ if (op->dtype.is_uint()) {
why do we care about the signedness? This just downcasts to 32 bits.
TVM uses uint to store uint8x4 (in function PrintType). Caring about the signedness will generate code like uint x = (uint)y instead of uint x = (int)y. What is your opinion?
Can we keep it as is? I do not see benefits from this change. Otherwise the entire PR LGTM. Thanks!
I think it's not necessary to revert this change if it's harmless. Considering that CodeGenCUDA::PrintType for uint8x4 generates "uint", this change makes sense.
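For reference, a small host-side sketch of the packing in question (the constant 3 and the printout are illustrative only, not from the PR): a 4-lane int8/uint8 broadcast folds into one 32-bit word, and the signedness check only decides whether the emitted cast reads (int) or (uint):

#include <cstdint>
#include <cstdio>

int main() {
  // Same folding as the snippet above: replicate the low byte 4 times.
  int64_t v = 3 & 0xFF;
  v = (v << 24) | (v << 16) | (v << 8) | v;  // 0x03030303
  // int8x4 is stored as "int" and uint8x4 as "uint" (per PrintType),
  // hence the two possible casts in the generated source.
  printf("int8x4:  (int)%lld\n", (long long)(int)v);
  printf("uint8x4: (uint)%lld\n", (long long)(unsigned)v);
  return 0;
}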
@vinx13 could you help to review the code? Thanks!
Thanks @huochaitiantang @wpan11nv, this is merged.
* Fix high-low bit bug in __pack_half2
* Fix vector load
* Add uint8 support for PrintVecElemLoadExpr and BroadcastNode
The above code is a padding kernel. Whether _2.x, _2.y, _2.z, _2.w are the correct indexes of A or not, the introduced variable _1 will be calculated. So emit code that guards each element load instead (see the sketch below).

@vinx13, could you please help review? Thanks!
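A hypothetical before/after sketch of the pattern described above (the kernel, the names _1 and _2, the padding width, and the bounds checks are all illustrative, not the PR's actual generated code). The old codegen performed the full vector load before selecting padding values, so out-of-bounds lanes were still read; the fix guards each element load:

#include <cstdio>

__global__ void pad_kernel(const int* A, int* B, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  int base = i * 4 - 4;  // a left padding of 4 elements, for illustration

  // Old pattern: the vector load happens even when base is out of bounds.
  //   int4 _1 = ((int4*)(A + base))[0];
  //   int4 _2 = make_int4((base >= 0) ? _1.x : 0, ...);

  // Fixed pattern: each lane is loaded only when its index is valid.
  int4 _2 = make_int4(
      (base + 0 >= 0 && base + 0 < n) ? A[base + 0] : 0,
      (base + 1 >= 0 && base + 1 < n) ? A[base + 1] : 0,
      (base + 2 >= 0 && base + 2 < n) ? A[base + 2] : 0,
      (base + 3 >= 0 && base + 3 < n) ? A[base + 3] : 0);
  ((int4*)B)[i] = _2;
}

int main() {
  const int n = 32;
  int *A, *B;
  cudaMallocManaged(&A, n * sizeof(int));
  cudaMallocManaged(&B, (n + 4) * sizeof(int));
  for (int k = 0; k < n; ++k) A[k] = k + 1;
  pad_kernel<<<1, (n + 4) / 4>>>(A, B, n);
  cudaDeviceSynchronize();
  // The first four outputs come from the padding, with no read of A[-4..-1].
  printf("%d %d %d %d %d\n", B[0], B[1], B[2], B[3], B[4]);
  return 0;
}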