[ARITH] Improve Canonical Simplification to Handle Fused Pattern #1711
Comments
Ideally Although it introduces some extra arithmetic operations, in practice we don't observe performance regression. So it is still okay. |
As far as I recall, the buffer index fetch has some ability to simplify such expressions, added by @sxjscience. Maybe someone can follow up on this. |
Yes, I've run into this problem and wrote the following code, https://github.com/dmlc/tvm/blob/master/src/lang/buffer.cc#L152-L220, to optimize some predefined patterns. |
I've had a similar issue, and my pattern looks more complicated. The codegen mechanism in TVM can't handle it, so we chose the low-level IR builder for 3D conv.
Thanks, |
The current implementation hasn't considered cases that involve “<”, like “... < 32”. Also, what's the simplified version of this pattern? (It looks complicated.)
From: xqdan
Sent: Friday, September 14, 2018 7:24:53 PM
To: dmlc/tvm
Cc: Xingjian SHI; Mention
Subject: Re: [dmlc/tvm] [ARITH] Improve Canonical Simplification to Handle Fused Pattern (#1711)
I've had a similar issue, and my pattern looks more complicated. The codegen mechanism in TVM can't handle it, so we chose the low-level IR builder for 3D conv.
Anyway, @sxjscience, could you take a look at this pattern? Can we transform it the way you did?
for (j, 0, 32) {
  for (k, 0, 2) {
    for (m, 0, 16) {
      for (n, 0, 16) {
        Apad5d[((((j*512) + (k*256)) + (m*16)) + n)] = select((((((((((bo.outer*512) + (j*16)) + m)/32) + (((((ko.outer*7) + (k*16)) + n) % 25)/5)) >= 2) && (((((((bo.outer*512) + (j*16)) + m)/32) + (((((ko.outer*7) + (k*16)) + n) % 25)/5)) - 2) < 32)) && (((((j*16) + m) % 32) + (((((ko.outer*7) + (k*16)) + n) % 25) % 5)) >= 2)) && ((((((j*16) + m) % 32) + (((((ko.outer*7) + (k*16)) + n) % 25) % 5)) - 2) < 32)), A.local.L1[((((((((((bo.outer*512) + (j*16)) + m)/32)*576) + ((((j*16) + m) % 32)*16)) + (((((ko.outer*32) + (k*16)) + n)/400)*20736)) + ((((((ko.outer*7) + (k*16)) + n) % 25)/5)*576)) + (((((ko.outer*32) + (k*16)) + n)/25) % 16)) + ((((((ko.outer*7) + (k*16)) + n) % 25) % 5)*16))], 0.000000h)
      }
    }
  }
}
Thanks,
|
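For readers following the pattern quoted above, here is a small brute-force sketch in Python of the kind of div/mod identities a simplifier could exploit on the fused index `(j*16) + m`. It is not TVM code. The extents of `j` and `m` are taken from the loop nest; the extent of `bo.outer` does not appear in the thread, so `bo_extent` is a hypothetical parameter, as is the function name.

```python
# Brute-force check of div/mod identities on the fused index (j*16 + m).
# j in [0, 32) and m in [0, 16) come from the quoted loop nest.
# bo_extent is an assumption: the thread does not state the range of bo.outer.

def check_fused_identities(bo_extent=4):
    for bo in range(bo_extent):
        for j in range(32):
            for m in range(16):
                fused = j * 16 + m
                # 512 is a multiple of 32, so the bo term factors out of the division.
                if (bo * 512 + fused) // 32 != bo * 16 + fused // 32:
                    return False
                # With 0 <= m < 16, the quotient by 32 depends only on j.
                if fused // 32 != j // 2:
                    return False
                # The remainder keeps the low bit of j and all of m.
                if fused % 32 != (j % 2) * 16 + m:
                    return False
    return True


if __name__ == "__main__":
    print(check_fused_identities())
```

All values here are non-negative, so Python's floor division agrees with the truncating division used in the lowered IR. A rewrite-based simplifier would need the loop bounds to justify the last two identities, which is why simple pattern rewrites alone are not enough.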
I found a related Halide PR (halide/Halide#2845) which might be interesting.
|
@sxjscience it's an im2col convolution. |
Most of the simplifications we talked about here involve bound checking as well as arithmetic templates, which is harder than simple rewrites. I have wanted to do this for quite a while; maybe it is a good time to rethink our arithmetic simplifier to handle these cases. |
Consolidating this issue into #2588. |
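To illustrate what "bound checking" means here, the following is a minimal interval-analysis sketch in Python. It is an illustration under stated assumptions, not the TVM simplifier: the `Interval` class and `provably_less_than` helper are hypothetical names, and the modulo rule only covers non-negative operands.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Interval:
    """Inclusive integer range [lo, hi]. Hypothetical helper, not a TVM class."""
    lo: int
    hi: int

    def __add__(self, other):
        return Interval(self.lo + other.lo, self.hi + other.hi)

    def scale(self, c):
        # Multiplication by a non-negative constant.
        assert c >= 0
        return Interval(self.lo * c, self.hi * c)

    def shift(self, c):
        # Addition of a (possibly negative) constant.
        return Interval(self.lo + c, self.hi + c)

    def mod(self, c):
        # Only handles non-negative operands and a positive modulus.
        assert self.lo >= 0 and c > 0
        if self.hi < c:
            return self  # x % c == x when 0 <= x < c
        return Interval(0, c - 1)


def provably_less_than(iv, bound):
    # True only if every value in the interval is below the bound.
    return iv.hi < bound


# Loop extents from the pattern earlier in the thread.
j = Interval(0, 31)
m = Interval(0, 15)

row = (j.scale(16) + m).mod(32)    # ((j*16) + m) % 32
kernel = Interval(0, 24).mod(5)    # (t % 25) % 5, for any non-negative t
cond = (row + kernel).shift(-2)    # ... - 2, compared against 32
```

With these bounds, `provably_less_than(row, 32)` holds, so a predicate such as `row < 32` could be dropped. The full condition `cond < 32` cannot be proven this way because its upper bound is 33, so the `select` guard has to stay.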
I have the following C++ code when testing TVM in CUDA:
When the input shape is 2D (5, 6), the lowered function looks close to a handwritten kernel:
However, for the 3D input shape (5, 2, 3), the lowered function looks different:
From my reading of the injective schedule, it seems all input axes are fused before the split, so the two cases above should have identical codegen. Is my understanding correct?
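As a rough sketch of the question above (plain Python, not TVM, with hypothetical helper names), fusing all axes of either shape produces the same 30-element iteration space, but recovering the original multi-index from the fused index takes one div/mod pair per axis. This is one plausible reason the lowered 3D code may carry extra index arithmetic unless the simplifier removes it; it is not a statement about what TVM actually emits.

```python
from itertools import product


def unravel(fused, shape):
    # Recover a row-major multi-index from a fused (flattened) index,
    # using one % and one // per axis.
    idx = []
    for extent in reversed(shape):
        idx.append(fused % extent)
        fused //= extent
    return tuple(reversed(idx))


def fused_order(shape):
    # Visit every element of `shape` through a single fused loop.
    total = 1
    for extent in shape:
        total *= extent
    return [unravel(f, shape) for f in range(total)]


order_2d = fused_order((5, 6))
order_3d = fused_order((5, 2, 3))
```

Both lists have 30 entries and follow row-major order, so the fused loop itself is identical; only the per-element index reconstruction differs between the two shapes.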