Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CodeGen][CUDA] Enhance CUDA codegen for SelectNode #4983

Merged
merged 1 commit into from
Mar 11, 2020

Conversation

wpan11nv
Copy link
Contributor

@wpan11nv wpan11nv commented Mar 4, 2020

  • This patch allows CUDA backend to emit correct code for
    selects with vector conditions, which may be produced
    by floordiv op lowering etc..

  • This already works for llvm BE, as llvm select instruction
    supports vector conditions.

Signed-off-by: Wei Pan [email protected]

Thanks for contributing to TVM! Please refer to guideline https://docs.tvm.ai/contribute/ for useful information and tips. After the pull request is submitted, please request code reviews from Reviewers by @ them in the pull request thread.

@wpan11nv
Copy link
Contributor Author

wpan11nv commented Mar 4, 2020

The test fails without this patch. It is also exposed by #4968, in which a simple kernel fails during the CUDA codegen

// attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 98
// attr [iter_var(threadIdx.x, , threadIdx.x)] thread_extent = 1024
T_relu[ramp(((blockIdx.x2048) + (threadIdx.x2)), 1, 2)] = max((placeholder[ramp(((blockIdx.x2048) + (threadIdx.x2)), 1, 2)] + placeholder[floordiv(ramp(((blockIdx.x2048) + (threadIdx.x2)), 1, 2), x2(3136))]), x2(0f))

@wpan11nv wpan11nv force-pushed the select_cg branch 2 times, most recently from 7c48a26 to 290c58a Compare March 4, 2020 19:46
- This patch allows CUDA backend to emit correct code for
  selects with vector conditions, which may be produced
  by floordiv op lowering etc..

- This already works for llvm BE, as llvm select instruction
  supports vector conditions.

Signed-off-by: Wei Pan <[email protected]>
@wpan11nv
Copy link
Contributor Author

wpan11nv commented Mar 9, 2020

Kindly ping. Can someone help review this PR?

@jmorrill
Copy link
Contributor

jmorrill commented Mar 9, 2020

While I am not qualified to give a review of this, I have applied your changes on this PR and it I was able to compile a mxnet model to a cuda tvm graphruntime.

autotvm also looks like it is working correctly.

@jmorrill
Copy link
Contributor

jmorrill commented Mar 9, 2020

Kindly ping. Can someone help review this PR?

@wpan11nv, I think you have to tag some reviewers from here:
https://github.com/apache/incubator-tvm/blob/master/CONTRIBUTORS.md#reviewers

@wpan11nv
Copy link
Contributor Author

wpan11nv commented Mar 9, 2020

@vinx13 Could you help review this PR?

@wpan11nv
Copy link
Contributor Author

wpan11nv commented Mar 9, 2020

While I am not qualified to give a review of this, I have applied your changes on this PR and it I was able to compile a mxnet model to a cuda tvm graphruntime.

autotvm also looks like it is working correctly.

@jmorrill Thanks for confirming this fix!

@masahi masahi self-assigned this Mar 10, 2020
@masahi
Copy link
Member

masahi commented Mar 10, 2020

@wpan11nv I'll take a look (@vinx13 is currently a grad student, busy). I have a commit right

@masahi
Copy link
Member

masahi commented Mar 11, 2020

LGTM, but it seems indentation is broken in cuda source codegen. Not important, but would be nice to clean it up.

extern "C" __global__ void default_function_kernel0( float* __restrict__ B,  float* __restrict__ A) {
    float4 _1;
            int4 _2 = make_int4(37, 37, 37, 37);
            int4 _3 = make_int4(0, 0, 0, 0);
            ushort4 _4;
            _4.x = (_2.x>=_3.x);
            _4.y = (_2.y>=_3.y);
            _4.z = (_2.z>=_3.z);
            _4.w = (_2.w>=_3.w);
              int4 _5 = (make_int4)(((((int)threadIdx.x) * 4))+(1*0), ((((int)threadIdx.x) * 4))+(1*1), ((((int)threadIdx.x) * 4))+(1*2), ((((int)threadIdx.x) * 4))+(1*3));
              int4 _6 = make_int4(37, 37, 37, 37);
              int4 _7;
              _7.x = (_5.x%_6.x);
              _7.y = (_5.y%_6.y);
              _7.z = (_5.z%_6.z);
              _7.w = (_5.w%_6.w);
            int4 _8 = make_int4(0, 0, 0, 0);
            ushort4 _9;
            _9.x = (_7.x>=_8.x);
            _9.y = (_7.y>=_8.y);
            _9.z = (_7.z>=_8.z);
            _9.w = (_7.w>=_8.w);
          ushort4 _10;
          _10.x = (_4.x&&_9.x);
          _10.y = (_4.y&&_9.y);
          _10.z = (_4.z&&_9.z);
          _10.w = (_4.w&&_9.w);
            int4 _11 = make_int4(37, 37, 37, 37);
            int4 _12 = make_int4(0, 0, 0, 0);
            ushort4 _13;
            _13.x = (_11.x<_12.x);
            _13.y = (_11.y<_12.y);
            _13.z = (_11.z<_12.z);
            _13.w = (_11.w<_12.w);
              int4 _14 = (make_int4)(((((int)threadIdx.x) * 4))+(1*0), ((((int)threadIdx.x) * 4))+(1*1), ((((int)threadIdx.x) * 4))+(1*2), ((((int)threadIdx.x) * 4))+(1*3));
              int4 _15 = make_int4(37, 37, 37, 37);
              int4 _16;
              _16.x = (_14.x%_15.x);
              _16.y = (_14.y%_15.y);
              _16.z = (_14.z%_15.z);
              _16.w = (_14.w%_15.w);
            int4 _17 = make_int4(0, 0, 0, 0);
            ushort4 _18;
            _18.x = (_16.x<=_17.x);
            _18.y = (_16.y<=_17.y);
            _18.z = (_16.z<=_17.z);
            _18.w = (_16.w<=_17.w);
          ushort4 _19;
          _19.x = (_13.x&&_18.x);
          _19.y = (_13.y&&_18.y);
          _19.z = (_13.z&&_18.z);
          _19.w = (_13.w&&_18.w);
        ushort4 _20;
        _20.x = (_10.x||_19.x);
        _20.y = (_10.y||_19.y);
        _20.z = (_10.z||_19.z);
        _20.w = (_10.w||_19.w);
        int4 _21 = (make_int4)(((((int)threadIdx.x) * 4))+(1*0), ((((int)threadIdx.x) * 4))+(1*1), ((((int)threadIdx.x) * 4))+(1*2), ((((int)threadIdx.x) * 4))+(1*3));
        int4 _22 = make_int4(37, 37, 37, 37);
        int4 _23;
        _23.x = (_21.x/_22.x);
        _23.y = (_21.y/_22.y);
        _23.z = (_21.z/_22.z);
        _23.w = (_21.w/_22.w);
          int4 _24 = (make_int4)(((((int)threadIdx.x) * 4))+(1*0), ((((int)threadIdx.x) * 4))+(1*1), ((((int)threadIdx.x) * 4))+(1*2), ((((int)threadIdx.x) * 4))+(1*3));
          int4 _25 = make_int4(37, 37, 37, 37);
          int4 _26;
          _26.x = (_24.x/_25.x);
          _26.y = (_24.y/_25.y);
          _26.z = (_24.z/_25.z);
          _26.w = (_24.w/_25.w);
        int4 _27 = make_int4(1, 1, 1, 1);
        int4 _28;
        _28.x = (_26.x-_27.x);
        _28.y = (_26.y-_27.y);
        _28.z = (_26.z-_27.z);
        _28.w = (_26.w-_27.w);
      int4 _29;
      _29.x = (bool(_20.x)?_23.x:_28.x);
      _29.y = (bool(_20.y)?_23.y:_28.y);
      _29.z = (bool(_20.z)?_23.z:_28.z);
      _29.w = (bool(_20.w)?_23.w:_28.w);
    _1.x = A[_29.x];
    _1.y = A[_29.y];
    _1.z = A[_29.z];
    _1.w = A[_29.w];
  (( float4*)(B + ((((int)threadIdx.x) * 4))))[0] = _1;
}

@masahi masahi merged commit afa8417 into apache:master Mar 11, 2020
@masahi
Copy link
Member

masahi commented Mar 11, 2020

Thanks @wpan11nv @jmorrill @vinx13

@wpan11nv
Copy link
Contributor Author

LGTM, but it seems indentation is broken in cuda source codegen. Not important, but would be nice to clean it up.

extern "C" __global__ void default_function_kernel0( float* __restrict__ B,  float* __restrict__ A) {
    float4 _1;
            int4 _2 = make_int4(37, 37, 37, 37);
            int4 _3 = make_int4(0, 0, 0, 0);
            ushort4 _4;
            _4.x = (_2.x>=_3.x);
            _4.y = (_2.y>=_3.y);
            _4.z = (_2.z>=_3.z);
            _4.w = (_2.w>=_3.w);
              int4 _5 = (make_int4)(((((int)threadIdx.x) * 4))+(1*0), ((((int)threadIdx.x) * 4))+(1*1), ((((int)threadIdx.x) * 4))+(1*2), ((((int)threadIdx.x) * 4))+(1*3));
              int4 _6 = make_int4(37, 37, 37, 37);
              int4 _7;
              _7.x = (_5.x%_6.x);
              _7.y = (_5.y%_6.y);
              _7.z = (_5.z%_6.z);
              _7.w = (_5.w%_6.w);
            int4 _8 = make_int4(0, 0, 0, 0);
            ushort4 _9;
            _9.x = (_7.x>=_8.x);
            _9.y = (_7.y>=_8.y);
            _9.z = (_7.z>=_8.z);
            _9.w = (_7.w>=_8.w);
          ushort4 _10;
          _10.x = (_4.x&&_9.x);
          _10.y = (_4.y&&_9.y);
          _10.z = (_4.z&&_9.z);
          _10.w = (_4.w&&_9.w);
            int4 _11 = make_int4(37, 37, 37, 37);
            int4 _12 = make_int4(0, 0, 0, 0);
            ushort4 _13;
            _13.x = (_11.x<_12.x);
            _13.y = (_11.y<_12.y);
            _13.z = (_11.z<_12.z);
            _13.w = (_11.w<_12.w);
              int4 _14 = (make_int4)(((((int)threadIdx.x) * 4))+(1*0), ((((int)threadIdx.x) * 4))+(1*1), ((((int)threadIdx.x) * 4))+(1*2), ((((int)threadIdx.x) * 4))+(1*3));
              int4 _15 = make_int4(37, 37, 37, 37);
              int4 _16;
              _16.x = (_14.x%_15.x);
              _16.y = (_14.y%_15.y);
              _16.z = (_14.z%_15.z);
              _16.w = (_14.w%_15.w);
            int4 _17 = make_int4(0, 0, 0, 0);
            ushort4 _18;
            _18.x = (_16.x<=_17.x);
            _18.y = (_16.y<=_17.y);
            _18.z = (_16.z<=_17.z);
            _18.w = (_16.w<=_17.w);
          ushort4 _19;
          _19.x = (_13.x&&_18.x);
          _19.y = (_13.y&&_18.y);
          _19.z = (_13.z&&_18.z);
          _19.w = (_13.w&&_18.w);
        ushort4 _20;
        _20.x = (_10.x||_19.x);
        _20.y = (_10.y||_19.y);
        _20.z = (_10.z||_19.z);
        _20.w = (_10.w||_19.w);
        int4 _21 = (make_int4)(((((int)threadIdx.x) * 4))+(1*0), ((((int)threadIdx.x) * 4))+(1*1), ((((int)threadIdx.x) * 4))+(1*2), ((((int)threadIdx.x) * 4))+(1*3));
        int4 _22 = make_int4(37, 37, 37, 37);
        int4 _23;
        _23.x = (_21.x/_22.x);
        _23.y = (_21.y/_22.y);
        _23.z = (_21.z/_22.z);
        _23.w = (_21.w/_22.w);
          int4 _24 = (make_int4)(((((int)threadIdx.x) * 4))+(1*0), ((((int)threadIdx.x) * 4))+(1*1), ((((int)threadIdx.x) * 4))+(1*2), ((((int)threadIdx.x) * 4))+(1*3));
          int4 _25 = make_int4(37, 37, 37, 37);
          int4 _26;
          _26.x = (_24.x/_25.x);
          _26.y = (_24.y/_25.y);
          _26.z = (_24.z/_25.z);
          _26.w = (_24.w/_25.w);
        int4 _27 = make_int4(1, 1, 1, 1);
        int4 _28;
        _28.x = (_26.x-_27.x);
        _28.y = (_26.y-_27.y);
        _28.z = (_26.z-_27.z);
        _28.w = (_26.w-_27.w);
      int4 _29;
      _29.x = (bool(_20.x)?_23.x:_28.x);
      _29.y = (bool(_20.y)?_23.y:_28.y);
      _29.z = (bool(_20.z)?_23.z:_28.z);
      _29.w = (bool(_20.w)?_23.w:_28.w);
    _1.x = A[_29.x];
    _1.y = A[_29.y];
    _1.z = A[_29.z];
    _1.w = A[_29.w];
  (( float4*)(B + ((((int)threadIdx.x) * 4))))[0] = _1;
}

Yes, I noticed that indention issue too. I will have a look. Thanks!

@wpan11nv wpan11nv deleted the select_cg branch March 19, 2020 18:18
trevor-m pushed a commit to trevor-m/tvm that referenced this pull request Apr 16, 2020
- This patch allows CUDA backend to emit correct code for
  selects with vector conditions, which may be produced
  by floordiv op lowering etc..

- This already works for llvm BE, as llvm select instruction
  supports vector conditions.

Signed-off-by: Wei Pan <[email protected]>
zhiics pushed a commit to neo-ai/tvm that referenced this pull request Apr 17, 2020
- This patch allows CUDA backend to emit correct code for
  selects with vector conditions, which may be produced
  by floordiv op lowering etc..

- This already works for llvm BE, as llvm select instruction
  supports vector conditions.

Signed-off-by: Wei Pan <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants