Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CODEGEN][CUDA] Fix vector load #5226

Merged
merged 3 commits into from
Apr 14, 2020
Merged

[CODEGEN][CUDA] Fix vector load #5226

merged 3 commits into from
Apr 14, 2020

Conversation

huochaitiantang
Copy link
Contributor

  • Fix high-low bit bug in __pack_half2.
  • Do not emit code of vector load by introducing an extra statement and vector store:
    int _1;
    int4 _2 = (make_int4)(
      ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) - 3))+(16*0), 
      ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) - 3))+(16*1), 
      ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) - 3))+(16*2), 
      ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) - 3))+(16*3));
    
    _1=(((signed char*)A)[_2.x] << 0);
    _1=_1 & ~(0x000000ff << 8) |(((signed char*)A)[_2.y] << 8);
    _1=_1 & ~(0x000000ff << 16) |(((signed char*)A)[_2.z] << 16);
    _1=_1 & ~(0x000000ff << 24) |(((signed char*)A)[_2.w] << 24);
    (( int*)(( signed char*)B + (((((int)blockIdx.x) * 88) + (((int)threadIdx.x) * 4)))))[0] = 
    (((((int)threadIdx.x) < 3) || (19 <= ((int)threadIdx.x))) ? (int)0 : _1);

The above code is a padding kernel. Whether _2.x, _2.y, _2.z, _2.w are the correct indexes of A or not, the introduced variable _1 will be calculated. So emit the following code instead:

  int4 _1 = (make_int4)(
    ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) - 3))+(16*0), 
    ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) - 3))+(16*1), 
    ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) - 3))+(16*2), 
    ((((((int)blockIdx.x) * 64) + ((int)threadIdx.x)) - 3))+(16*3));

  (( int*)(( signed char*)B + (((((int)blockIdx.x) * 88) + (((int)threadIdx.x) * 4)))))[0] = 
    (((((int)threadIdx.x) < 3) || (19 <= ((int)threadIdx.x))) 
    ? (int)0 
    : ((0x000000ff << 0) & (((signed char*)A)[_1.x] << 0))|
      ((0x000000ff << 8) & (((signed char*)A)[_1.y] << 8))|
      ((0x000000ff << 16) & (((signed char*)A)[_1.z] << 16))|
      ((0x000000ff << 24) & (((signed char*)A)[_1.w] << 24)));

@vinx13, could you please help review? Thanks!

@tqchen
Copy link
Member

tqchen commented Apr 4, 2020

also cc @wpan11nv @ZihengJiang

@@ -291,7 +291,7 @@ static inline __device__ __host__ unsigned
__pack_half2(const half x, const half y) {
unsigned v0 = *((unsigned short *)&x);
unsigned v1 = *((unsigned short *)&y);
return (v0 << 16) | v1;
return (v1 << 16) | v0;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

void CodeGenC::PrintVecElemLoadExpr(
DataType t, int i, const std::string& value, std::ostream& os) {
CHECK_GT(t.lanes(), 1);
if (t.is_int() && t.bits() == 8) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uint8?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already supported unit8.

(0, 0)), mode='constant', constant_values=0)
tvm.testing.assert_allclose(b.asnumpy(), ref)

check_cuda("int8", 64, 16, 3, 4)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uint8 test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already added uint8 test.

void CodeGenCUDA::PrintVecElemLoadExpr(
DataType t, int i, const std::string& value, std::ostream& os) {
CHECK_GT(t.lanes(), 1);
if (t.is_int() && t.bits() == 8) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

uint8_t

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already supported unit8.

if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

check if float16 is supported

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Already checked.

@huochaitiantang
Copy link
Contributor Author

Thanks for your review, @wpan11nv . The new commit supported uint8, and it also fixed the unit8 bug in the code generation of BroadcastNode.

// make_int8x4
const int64_t *p = as_const_int(op->value);
CHECK(p);
int64_t v = *p & 0xFF;
v = (v << 24) | (v << 16) | (v << 8) | v;
os << "(int)" << v;
if (op->dtype.is_uint()) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do we care the signedness? this just downcasts to 32 bits,.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TVM uses uint to store unit8x4 (in function PrintType). The care will generate code like unit x = (unit)y, instead of unit x = (int)y. And what is your further opinion?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep it as is? I do not see benefits from this change. Otherwise the entire PR LGTM. Thanks!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's not necessary to revert this change, if it's harmless. Consider that CodeGenCUDA::PrintType for uint8x4 generates "uint", this change somehow makes sense.

@huochaitiantang
Copy link
Contributor Author

@vinx13 could you help to review the code? Thanks!

@vinx13 vinx13 merged commit d2e58ad into apache:master Apr 14, 2020
@vinx13
Copy link
Member

vinx13 commented Apr 14, 2020

Thanks @huochaitiantang @wpan11nv this is merged

trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
* Fix high-low bit bug in __pack_half2

* Fix vector load

* Add unit8 support for PrintVecElemLoadExpr and BroadcastNode
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020
* Fix high-low bit bug in __pack_half2

* Fix vector load

* Add unit8 support for PrintVecElemLoadExpr and BroadcastNode
dpankratz pushed a commit to dpankratz/incubator-tvm that referenced this pull request Apr 24, 2020
* Fix high-low bit bug in __pack_half2

* Fix vector load

* Add unit8 support for PrintVecElemLoadExpr and BroadcastNode
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants