
[ARITH] Improve Canonical Simplification to Handle Fused Pattern #1711

Closed
ke1337 opened this issue Sep 13, 2018 · 9 comments

Comments

@ke1337
Contributor

ke1337 commented Sep 13, 2018

I have the following C++ code for testing TVM with CUDA:

    // 2D input shape; swap in the commented line below for the 3D case
    tvm::Array<tvm::Expr> a_shape1 = {5, 6};
    // tvm::Array<tvm::Expr> a_shape1 = {5, 2, 3};
    tvm::Tensor tvm_X = tvm::placeholder(a_shape1, tvm::Float(32), "A");
    // Sigmoid: 1/(1+exp(-x)) where x > 0, otherwise 1 - 1/(1+exp(x))
    tvm::Tensor tvm_Y = ::topi::where(less(0, tvm_X), 1 / (1 + exp(negative(tvm_X))), 1 - 1 / (1 + exp(tvm_X)));
    auto target1 = tvm::target::cuda();
    // Injective CUDA schedule for the output tensor
    auto S1 = topi::cuda::schedule_injective(target1, {tvm_Y});

    // Lower with A as input and tensor as output, then print the lowered body
    auto args1 = tvm::Array<tvm::Tensor>({tvm_X, tvm_Y});
    std::unordered_map<tvm::Tensor, tvm::Buffer> binds1;
    auto config1 = tvm::build_config();
    config1->restricted_func = true;
    auto lowered1 = tvm::lower(S1, args1, "Sigmoid", binds1, config1);

    std::cout << lowered1[0]->body << std::endl;

When the input shape is 2D (5, 6), the lowered function looks close to a handwritten kernel:

  if ((threadIdx.x < 30)) {
    tensor[threadIdx.x] = tvm_if_then_else(((0.000000f < A[threadIdx.x]) == (uint1)0), (1.000000f - (1.000000f/(exp(A[threadIdx.x]) + 1.000000f))), (1.000000f/(exp((0.000000f - A[threadIdx.x])) + 1.000000f)))
  }

However, for the 3D input shape (5, 2, 3), the lowered function looks different:

  if ((threadIdx.x < 30)) {
    tensor[(((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3)))] = tvm_if_then_else(((0.000000f < A[(((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3)))]) == (uint1)0), (1.000000f - (1.000000f/(exp(A[(((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3)))]) + 1.000000f))), (1.000000f/(exp((0.000000f - A[(((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3)))])) + 1.000000f)))
  }

From my reading of the injective schedule, all axes are fused before the split, so the two cases above should generate identical code. Is my understanding correct?
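The long 3D index expression can be reproduced by hand. Below is a minimal sketch in plain C++ (it does not call TVM, and the helper names unfuse_523, flatten_523 and roundtrip_is_identity are hypothetical) that splits a fused index over shape (5, 2, 3) into per-axis indices the same way the lowered IR does, then flattens them back in row-major order.

```cpp
#include <array>
#include <cassert>

// Hypothetical helper: split a fused index over a row-major (5, 2, 3)
// tensor back into per-axis indices (i, j, k), mirroring the lowered IR.
inline std::array<int, 3> unfuse_523(int fused) {
    assert(fused >= 0);  // assume a non-negative index such as threadIdx.x
    return {fused / 6,        // axis 0: stride 2 * 3 = 6
            (fused / 3) % 2,  // axis 1: stride 3, extent 2
            fused % 3};       // axis 2: stride 1, extent 3
}

// Row-major flattening of (i, j, k) into a buffer offset for shape (5, 2, 3).
inline int flatten_523(const std::array<int, 3>& idx) {
    return idx[0] * 6 + idx[1] * 3 + idx[2];
}

// True if flatten(unfuse(t)) == t for every t that passes the
// `threadIdx.x < 30` guard, i.e. the 3D index expression is just threadIdx.x.
inline bool roundtrip_is_identity() {
    for (int t = 0; t < 5 * 2 * 3; ++t) {
        if (flatten_523(unfuse_523(t)) != t) return false;
    }
    return true;
}
```

flatten_523(unfuse_523(t)) is exactly the expression ((t/6)*6) + (((t/3) % 2)*3) + (t % 3) from the 3D kernel, so the round trip returning t for every guarded thread shows the extra arithmetic is redundant rather than wrong.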

ke1337 changed the title from "CUDA schedule_injective creates different code after lower with different input shape" to "CUDA schedule_injective creates different code with different input shape" (Sep 13, 2018)
@merrymercy
Member

merrymercy commented Sep 13, 2018

Ideally, (((threadIdx.x/6)*6) + ((((threadIdx.x/3) % 2)*3) + (threadIdx.x % 3))) should simplify to threadIdx.x, since the two are equivalent. However, the simplifier in TVM cannot handle this case. (cc @tqchen)

Although it introduces some extra arithmetic operations, in practice we don't observe a performance regression, so it is still okay.
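A worked derivation of the equivalence, added for illustration. It assumes x = threadIdx.x is non-negative and uses two facts: the nested-division identity ⌊⌊x/3⌋/2⌋ = ⌊x/6⌋, and c⌊y/c⌋ + (y mod c) = y for y ≥ 0, c > 0. The second fact is applied twice, once with y = ⌊x/3⌋, c = 2 and once with y = x, c = 3:

$$
\begin{aligned}
6\left\lfloor \tfrac{x}{6} \right\rfloor + 3\left(\left\lfloor \tfrac{x}{3} \right\rfloor \bmod 2\right)
  &= 3\left(2\left\lfloor \tfrac{\lfloor x/3 \rfloor}{2} \right\rfloor + \left(\left\lfloor \tfrac{x}{3} \right\rfloor \bmod 2\right)\right)
   = 3\left\lfloor \tfrac{x}{3} \right\rfloor, \\
3\left\lfloor \tfrac{x}{3} \right\rfloor + (x \bmod 3) &= x.
\end{aligned}
$$

A simplifier that recognizes the (y/c)*c + y%c pattern, and can regroup the sum so the matching terms sit next to each other, would therefore reduce the 3D index to threadIdx.x.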

tqchen changed the title from "CUDA schedule_injective creates different code with different input shape" to "[ARITH] Improve Canonical Simplification to Handle Fused Pattern" (Sep 13, 2018)
@tqchen
Member

tqchen commented Sep 13, 2018

As far as I recall, @sxjscience added some ability to simplify such expressions in the buffer index fetch. Maybe someone can follow up on this.

@sxjscience
Member

Yes, I've run into this problem and wrote the following code to simplify some predefined patterns: https://github.com/dmlc/tvm/blob/master/src/lang/buffer.cc#L152-L220
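The fused index can be reduced by repeatedly applying one rule, (y / c) * c + y % c -> y. Below is a minimal sketch of that idea; it is not the buffer.cc code linked above. It assumes the index has already been brought into a sum of terms k * ((x / d) % m) over a single non-negative variable x, and the Term type and merge_div_mod function are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical canonical term: coeff * ((x / div) % mod), where mod == 0
// means "no modulo", i.e. coeff * (x / div). x is assumed non-negative.
struct Term {
    long coeff;
    long div;
    long mod;  // 0 => no modulo
};

// Repeatedly apply  k*m*(x/(d*m)) + k*((x/d) % m)  ->  k*(x/d),
// which is  c*(y/c) + y%c -> y  with y = x/d, c = m, scaled by k.
// Terms that never match are left untouched.
std::vector<Term> merge_div_mod(std::vector<Term> terms) {
    for (const Term& t : terms) assert(t.div > 0 && t.mod >= 0);
    bool changed = true;
    while (changed) {
        changed = false;
        for (std::size_t i = 0; i < terms.size() && !changed; ++i) {
            for (std::size_t j = 0; j < terms.size() && !changed; ++j) {
                const Term& q = terms[i];  // candidate k*m * (x / (d*m))
                const Term& r = terms[j];  // candidate k * ((x / d) % m)
                if (i == j || q.mod != 0 || r.mod == 0) continue;
                if (q.div == r.div * r.mod && q.coeff == r.coeff * r.mod) {
                    Term merged{r.coeff, r.div, 0};
                    // Erase the higher index first so the lower one stays valid.
                    std::size_t hi = i > j ? i : j;
                    std::size_t lo = i > j ? j : i;
                    terms.erase(terms.begin() + hi);
                    terms.erase(terms.begin() + lo);
                    terms.push_back(merged);
                    changed = true;
                }
            }
        }
    }
    return terms;
}
```

For the 3D kernel, ((x/6)*6) + (((x/3) % 2)*3 + (x % 3)) corresponds to the terms {6, 6, 0}, {3, 3, 2}, {1, 1, 3}. Two merges reduce it to {1, 1, 0}, i.e. x. The harder parts a real simplifier must handle are getting arbitrary expressions into this sum-of-terms form and proving that x is non-negative.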

@xqdan
Contributor

xqdan commented Sep 14, 2018

I've had a similar issue, and my pattern is more complicated. The codegen mechanism in TVM can't handle it, so we chose the low-level IR builder for 3D convolution.
Anyway, @sxjscience, could you take a look at this pattern? Can we transform it the way you did?

for (j, 0, 32) {
 for (k, 0, 2) {
   for (m, 0, 16) {
     for (n, 0, 16) {
       Apad5d[((((j*512) + (k*256)) + (m*16)) + n)] = select((((((((((bo.outer*512) + (j*16)) + m)/32) + (((((ko.outer*7) + (k*16)) + n) % 25)/5)) >= 2) && (((((((bo.outer*512) + (j*16)) + m)/32) + (((((ko.outer*7) + (k*16)) + n) % 25)/5)) - 2) < 32)) && (((((j*16) + m) % 32) + (((((ko.outer*7) + (k*16)) + n) % 25) % 5)) >= 2)) && ((((((j*16) + m) % 32) + (((((ko.outer*7) + (k*16)) + n) % 25) % 5)) - 2) < 32)), A.local.L1[((((((((((bo.outer*512) + (j*16)) + m)/32)*576) + ((((j*16) + m) % 32)*16)) + (((((ko.outer*32) + (k*16)) + n)/400)*20736)) + ((((((ko.outer*7) + (k*16)) + n) % 25)/5)*576)) + (((((ko.outer*32) + (k*16)) + n)/25) % 16)) + ((((((ko.outer*7) + (k*16)) + n) % 25) % 5)*16))], 0.000000h)
     }
   }
 }
}

Thanks,

@sxjscience
Member

sxjscience commented Sep 14, 2018 via email

@merrymercy
Member

merrymercy commented Sep 14, 2018

I found a related Halide PR (halide/Halide#2845) that might be interesting.
It introduces a template-based rewrite mechanism, so new rules can be defined as follows:

    rewrite((x + y) + w < x + z, y + w < z)
    rewrite(select(x, y, z) + select(x, w, u), select(x, y + w, z + u))
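To show what a rule like the second one does mechanically, here is a minimal runtime sketch of pattern-based rewriting. Halide's version builds patterns with C++ templates; this sketch instead interprets patterns at runtime, and every name in it (Node, leaf, call, match, substitute, rewrite) is hypothetical. Leaves whose name starts with '?' act as wildcards.

```cpp
#include <cstddef>
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical mini-IR: an operator name with children, or a leaf
// (variable/constant name) when args is empty.
struct Node;
using Expr = std::shared_ptr<const Node>;
struct Node {
    std::string op;          // "+", "select", or a leaf name
    std::vector<Expr> args;  // empty for leaves
};

Expr leaf(const std::string& name) { return std::make_shared<Node>(Node{name, {}}); }
Expr call(const std::string& op, std::vector<Expr> args) {
    return std::make_shared<Node>(Node{op, std::move(args)});
}

using Bindings = std::map<std::string, Expr>;

// Structural equality of two trees.
bool same_tree(const Expr& a, const Expr& b) {
    if (a->op != b->op || a->args.size() != b->args.size()) return false;
    for (std::size_t i = 0; i < a->args.size(); ++i)
        if (!same_tree(a->args[i], b->args[i])) return false;
    return true;
}

// Match `pat` against `e`. A wildcard must bind to structurally equal
// subtrees at every occurrence, so select(?x, ..) + select(?x, ..)
// only matches when both selects share the same condition.
bool match(const Expr& pat, const Expr& e, Bindings& b) {
    if (pat->args.empty() && !pat->op.empty() && pat->op[0] == '?') {
        auto it = b.find(pat->op);
        if (it == b.end()) { b[pat->op] = e; return true; }
        return same_tree(it->second, e);
    }
    if (pat->op != e->op || pat->args.size() != e->args.size()) return false;
    for (std::size_t i = 0; i < pat->args.size(); ++i)
        if (!match(pat->args[i], e->args[i], b)) return false;
    return true;
}

// Instantiate the right-hand side of a rule with the bound subtrees.
Expr substitute(const Expr& tmpl, const Bindings& b) {
    if (tmpl->args.empty()) {
        auto it = b.find(tmpl->op);
        return it != b.end() ? it->second : tmpl;
    }
    std::vector<Expr> args;
    for (const Expr& a : tmpl->args) args.push_back(substitute(a, b));
    return call(tmpl->op, std::move(args));
}

// Apply one rule at the root; returns `e` unchanged if the pattern fails.
Expr rewrite(const Expr& e, const Expr& lhs, const Expr& rhs) {
    Bindings b;
    return match(lhs, e, b) ? substitute(rhs, b) : e;
}
```

This only tries a single rule at the root. A full simplifier would walk every subexpression and iterate over a rule list until nothing changes, and, as the next comments point out, many useful rules also need side conditions such as bounds.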

@xqdan
Contributor

xqdan commented Sep 14, 2018

@sxjscience It's an im2col convolution.

@tqchen
Member

tqchen commented Sep 14, 2018

Most of the simplifications we are talking about here involve bound checking as well as arithmetic templates, which is harder than simple rewrites. I have wanted to do this for quite a while; maybe it is a good time to rethink our arithmetic simplifier to handle these cases.
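To illustrate why bound information matters, here is a hypothetical sketch of a tiny interval analysis that decides when x % c can be replaced by x and x / c by 0. It assumes closed integer ranges, non-negative operands for division, and positive constants; the Interval type and helper names are made up for illustration and are not TVM's API.

```cpp
#include <cassert>

// Hypothetical closed integer interval [lo, hi] of values an expression
// can take, e.g. a loop variable's range from its extent.
struct Interval {
    long lo, hi;
};

// Interval arithmetic for the operators that appear in fused indices.
Interval add(Interval a, Interval b) { return {a.lo + b.lo, a.hi + b.hi}; }

Interval mul_const(Interval a, long c) {
    assert(c > 0);  // a positive constant preserves ordering
    return {a.lo * c, a.hi * c};
}

// For non-negative x and c > 0, x / c is monotone, so bounds map directly.
Interval div_const(Interval a, long c) {
    assert(c > 0 && a.lo >= 0);
    return {a.lo / c, a.hi / c};
}

// Bound-checked rewrite condition: if 0 <= x < c is provable, then
// x % c -> x and x / c -> 0 are both valid.
bool fits_below(Interval x, long c) { return c > 0 && x.lo >= 0 && x.hi < c; }
```

With the loop extents from the im2col IR above (k in [0, 2), n in [0, 16)), k*16 + n lies in [0, 31], so (k*16 + n) / 32 folds to 0 and (k*16 + n) % 32 to k*16 + n, while a % 25 must stay. Terms involving ko.outer would additionally need its range, which this sketch does not attempt to derive.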

@tqchen
Member

tqchen commented Feb 12, 2019

Consolidating this issue into #2588.
