[TASK] TextPrinter Support for Unified TVM IRs #48
So ideally we want the text format & hybrid script to be two different printing methods for TIR, both of which are parseable?

Yes. But as a path for progression, we can support the text format first and hybrid bidirection later.

What is hybrid bidirection? Bidirectional translation between the text format & hybrid?

Sorry, I meant hybrid parsing and printing.
Generally, it is OK to use the current ReprPrinter as long as we do some minor mutation. However, there are some details to be determined.

Function Syntax

I think both are not the perfect solution, but I have no other ideas. Would love to hear your opinions.

Annotation

I'd like to add a kind of annotation to functions to show whether a function is a Relay or TIR function. One possible syntax would be a distinct keyword per function kind (see the sketch below).

Function Body

There is no doubt that Relay functions and TIR functions have different stmts and exprs (e.g. let bindings in Relay and BufferLoad in TIR). So I think we can ignore the differences in the function body, but try to make the function signature itself unified.
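As a sketch of the annotation idea, assuming a keyword-per-kind scheme (the helper below is hypothetical; primfn is the keyword the later examples in this thread use for TIR functions, while Relay functions conventionally print as fn):

def function_keyword(kind: str) -> str:
    # Hypothetical printer hook: the header keyword announces the IR kind,
    # so a reader (and the parser) can tell Relay and TIR functions apart.
    keywords = {"relay": "fn", "tir": "primfn"}  # assumed mapping
    return keywords[kind]

print(function_keyword("tir"))    # primfn
print(function_keyword("relay"))  # fn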
FYI, I am in the process of refactoring out the tensor field (func) in Realize and Provide and replacing them with BufferRealize, BufferStore and BufferLoad, so we don't need to consider supporting these nodes in the text format.

Related change in the upstream: apache/tvm#5372
A comparison between the repr format and the text format (which is problematic now):

[23:41:44] /home/spectrometer/tvm-upstream/src/te/schedule/bound.cc:128: not in feed graph consumer = compute(B, 0x1fb5cf0)
PrimFunc([A0, A1, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
// attr [B.v0] storage_scope = "global"
allocate B.v0[float32 * n]
// attr [B.v1] storage_scope = "global"
allocate B.v1[float32 * n]
for (i, 0, m) {
for (j, 0, n) {
B.v0[j] = (A0[((i*stride) + (j*stride))] + 2f)
B.v1[j] = (A0[((i*stride) + (j*stride))]*3f)
}
for (j, 0, n) {
C[((i*stride) + (j*stride))] = (A1[((i*stride) + (j*stride))] + B.v0[j])
}
}
}
[23:41:44] /home/spectrometer/tvm-upstream/src/te/schedule/bound.cc:128: not in feed graph consumer = compute(B, 0x1fb5cf0)
class Module:
PrimFunc(%A0: Buffer([%m, %n], "float32", name=%A01),
%A1: Buffer([%m, %n], "float32", name=%A11),
%C: Buffer([%m, %n], "float32", name=%C1)) {
attr [%B.v0] "storage_scope" = "global" {
allocate(%B.v0, "float32", [%n]) if "bool"(1) {
attr [%B.v1] "storage_scope" = "global" {
allocate(%B.v1, "float32", [%n]) if "bool"(1) {
for (%i, 0, %m) "serial" {
for (%j, 0, %n) "serial" {
%B.v0[%j] = (%A0[((%i*%stride) + (%j*%stride1))] + "float32"(2)) if "bool"(1)
%B.v1[%j] = (%A0[((%i*%stride) + (%j*%stride1))]*"float32"(3)) if "bool"(1)
}
for (%j1, 0, %n) "serial" {
%C[((%i*%stride2) + (%j1*%stride3))] = (%A1[((%i*%stride4) + (%j1*%stride5))] + %B.v0[%j1]) if "bool"(1)
}
}
}
}
}
}
}

Several Problems

1. The names of Vars. For readability, the name_hint of a Var is useful, but name_hints are not guaranteed to be unique.
2. The dtype of Vars. I'm not very clear where Var is used in the upstream AST. As far as I can see now, when printing a Var the best way is to just give its name_hint, but in principle we have to determine its dtype as well. In most cases the dtype doesn't matter, though. I'm not sure about this point.
Some additional thoughts after seeing the example.

A strawman, though I am not too happy with the syntax:

attr ("storage_scope", %Bv0) = "global"
For vars like

%A0: Buffer([%m, %n], "float32", name=%A0_1),
%A1: Buffer([%m, %n], "float32", name=%A1_1),
%C: Buffer([%m, %n], "float32", name=%C_1)

one way is to print the Type the first time we encounter the Var (sketched below):

%A0: Buffer([%m : int32, %n : int32], "float32", name=%A0_1),
%A1: Buffer([%m, %n], "float32", name=%A1_1),
%C: Buffer([%m, %n], "float32", name=%C_1)

Another way is to declare their defs somewhere else first. The third way is to put them in meta. I think all of these ways are somewhat strange. And for vars like

%A0[((%i*%stride) + (%j*%stride1))]

it looks more like a placeholder for an unknown buffer shape. Does their Type/dtype matter?
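A minimal sketch of the first-encounter option (a hypothetical helper, not the actual TVM printer): remember which Vars have already been printed with their type, and print later occurrences by name only.

class VarPrinter:
    def __init__(self):
        self.typed = set()  # names whose type annotation was already printed

    def print_var(self, name, type_str):
        if name in self.typed:
            return "%" + name                    # later occurrences: name only
        self.typed.add(name)
        return "%{}: {}".format(name, type_str)  # first occurrence: name + type

p = VarPrinter()
print(p.print_var("m", "int32"))  # %m: int32
print(p.print_var("m", "int32"))  # %m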
I think the first-encounter style, or declaring the defs somewhere else, makes sense. Formally, we should not put buffer_map in the argument list (that would make Buffer a type, and Buffers are not types); we might want to have some section like C++'s initializer list (the corresponding hybrid script part is the buffer bind).
It seems that Relay's text printer doesn't support printing PrimType and PointerType. Will this be fixed in the upstream?

I think it will be up to us to fix them :)
Discussion on printing Buffer/Var, META usage & recovery

An inevitable problem in the text format is how to print buffer_map:

PrimFunc(%A0: handle, %A1: handle, %C: handle)
%C: Buffer(%C_1, [%m, %n], float32, [%stride, %stride_1], %C_2)
%A0: Buffer(%A0_1, [%m, %n], float32, [%stride_2, %stride_3], %A0_2)
%A1: Buffer(%A1_1, [%m, %n], float32, [%stride_4, %stride_5], %A1_2) {
attr [%B_v0] "storage_scope" = "global" {
allocate(%B_v0, float32, [%n]) {
attr [%B_v1] "storage_scope" = "global" {
allocate(%B_v1, float32, [%n]) {
for (%i, 0, %m) {
for (%j, 0, %n) {
%B_v0[%j] = (%A0_1[((%i*%stride_2) + (%j*%stride_3))] + float32(2))
%B_v1[%j] = (%A0_1[((%i*%stride_2) + (%j*%stride_3))]*float32(3))
}
for (%j_1, 0, %n) {
%C_1[((%i*%stride) + (%j_1*%stride_1))] = (%A1_1[((%i*%stride_4) + (%j_1*%stride_5))] + %B_v0[%j_1])
}
}
}
}
}
}
}

C0. Print all Buffer information out

Pros: If we avoid using META in printing, then the text format is fully writable and modifiable by the user, as long as the user follows the syntax rules.

Cons: It's really long to print all the Buffer info, which hurts readability.

C1. Put Buffers in META

Pros: Clearer buffer_map printing.

Cons: Nodes stored in META are opaque, so the result is hard for users to read and modify.
I further propose several ways to deal with the cons of C1.

C1.1 Scan the Vars in Buffers in advance and store a name map for them in META

When printing, we give each Var we encounter a unique name_hint.

C1.2 Scan all the Vars in the AST and store a name map for them

We can do this more aggressively to totally alleviate the problem: just scan all the Vars in the AST and store a name map for them in META. Then for every Var occurrence we can print just a name_hint, no matter how the Var is used (see the sketch below).

Pros: Simple and works generally. Note that it also works if we don't put the Buffers into META.
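A sketch of the name-map idea behind C1.1/C1.2 (class and method names are assumed): walk the AST once and give every distinct Var object a unique printed name by suffixing _1, _2, ... on name_hint collisions, which is exactly the %stride, %stride_1 pattern in the dumps above.

from collections import namedtuple

Var = namedtuple("Var", "name_hint")  # stand-in for tir::Var

class NameAllocator:
    def __init__(self):
        self.used = set()
        self.memo = {}  # id(var) -> unique printed name

    def unique_name(self, var):
        if id(var) in self.memo:
            return self.memo[id(var)]
        name, i = var.name_hint, 0
        while name in self.used:  # collision -> append a numeric suffix
            i += 1
            name = "{}_{}".format(var.name_hint, i)
        self.used.add(name)
        self.memo[id(var)] = name
        return name

alloc = NameAllocator()
a, b = Var("stride"), Var("stride")  # two distinct Vars with the same name_hint
print(alloc.unique_name(a))          # stride
print(alloc.unique_name(b))          # stride_1
print(alloc.unique_name(a))          # stride (stable across occurrences)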
I will land a version that avoids using META then.

Is it always possible to convert between Type and dtype?
Type to dtype: https://github.com/apache/incubator-tvm/blob/master/include/tvm/tir/op.h#L70 (note that it is always safe to construct a Var with a Type, because Type is more fine-grained). It is fine to print a Var without its Type (only with a dtype) but load it back containing a Type. We cannot always deduce the Type from a dtype, though.
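To make that asymmetry concrete, a toy sketch with plain strings standing in for the real tvm.ir types (function names here are illustrative, not the op.h API):

def dtype_of_type(ty: str) -> str:
    # Type -> dtype always succeeds: any Pointer(...) erases to the opaque
    # dtype "handle", and a PrimType simply carries its dtype.
    return "handle" if ty.startswith("Pointer") else ty

def type_of_dtype(dtype: str) -> str:
    # dtype -> Type is lossy: "handle" alone cannot say what is pointed to,
    # so a PointerType's pointee cannot be recovered.
    if dtype == "handle":
        raise ValueError("cannot recover a PointerType from dtype 'handle'")
    return dtype  # interpreted as PrimType(dtype)

print(dtype_of_type("Pointer(float32)"))  # handle
print(type_of_dtype("int32"))             # int32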
BTW, if identifiers don't begin with %, the lexer will confuse them with the CNAME used in meta. So I will keep % for now.

I see. We will need to at least print mod as truncmod instead of %, then.
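A toy maximal-munch lexer over just these two token rules (a regex stand-in for the g4 lexer attached later in this thread, for illustration only) shows why: once identifiers carry the % sigil, %a%b lexes as two adjacent IDENTIFIERs and the intended MOD token disappears, so printing mod as truncmod(...) keeps the format unambiguous.

import re

# IDENTIFIER and MOD follow the grammar's rules; WS is skipped like WHITESPACE.
LEX = re.compile(r"(?P<IDENTIFIER>%[a-zA-Z][a-zA-Z_0-9]*)|(?P<MOD>%)|(?P<WS>\s+)")

def lex(src):
    return [(m.lastgroup, m.group()) for m in LEX.finditer(src) if m.lastgroup != "WS"]

print(lex("%a % %b"))  # [('IDENTIFIER', '%a'), ('MOD', '%'), ('IDENTIFIER', '%b')]
print(lex("%a%b"))     # [('IDENTIFIER', '%a'), ('IDENTIFIER', '%b')] -- MOD swallowed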
I've implemented a printer & parser for my current syntax. I'll attach the g4 file below for discussion in our next meeting. Note that it doesn't support META in Attr yet, and Attr (and DictAttr) is the only place that will result in nodes being put into META. Several points:
PrimFunc(%A0_1: handle, %A1_1: handle, %C_1: handle) -> ()
buffers={%A1: Buffer(%A1_2: handle, float32, [%m: int32, %n: int32], [%stride: int32, %stride_1: int32], 0, "global", 128, 1, "auto"),
%C: Buffer(%C_2: handle, float32, [%m, %n], [%stride_2: int32, %stride_3: int32], 0, "global", 128, 1, "auto"),
%A0: Buffer(%A0_2: handle, float32, [%m, %n], [%stride_4: int32, %stride_5: int32], 0, "global", 128, 1, "auto")}
buffer_map={%C_1: %C, %A0_1: %A0, %A1_1: %A1} {
attr [%B_v0: handle] "storage_scope" = "global" {
allocate(%B_v0, float32, [%n]) {
attr [%B_v1: handle] "storage_scope" = "global" {
allocate(%B_v1, float32, [%n]) {
for (%i: int32, 0, %m) {
for (%j: int32, 0, %n) {
%B_v0[%j] = (%A0_2[((%i*%stride_4) + (%j*%stride_5))] + float32(2))
%B_v1[%j] = (%A0_2[((%i*%stride_4) + (%j*%stride_5))]*float32(3))
}
for (%j_1: int32, 0, %n) {
%C_2[((%i*%stride_2) + (%j_1*%stride_3))] = (%A1_2[((%i*%stride) + (%j_1*%stride_1))] + %B_v0[%j_1])
}
}
}
}
}
}
}

grammar tir;
module
: function* METADATA?
;
function
: PRIMFUNC '(' varNode (',' varNode)* ')' '->' typeExpr
'buffers' '=' '{' buffer_list? '}'
'buffer_map' '=' '{' buffer_map_list? '}'
body
;
buffer_list
: IDENTIFIER ':' bufferNode (',' IDENTIFIER ':' bufferNode)*
;
buffer_map_list
: IDENTIFIER ':' IDENTIFIER (',' IDENTIFIER ':' IDENTIFIER)*
;
body
: '{' stmtNode* '}'
;
node
: primExprNode
| stmtNode
| rangeNode
| commReducerNode
| arrayNode
| itervarNode
| bufferNode
| meta
;
arrayNode
: '[' node (',' node)* ']'
;
stmtNode
: LET varNode '=' primExprNode body # letStmt
| ATTR '[' (varNode|itervarNode) ']' STRINGIMM '=' primExprNode body # attrStmt
| ASSERT '(' primExprNode ',' primExprNode ')' body # assertStmt
| varNode arrayNode '=' primExprNode (IF primExprNode)? # storeStmt
| ALLOCATE '(' varNode ',' DTYPE ',' arrayNode ')' (IF primExprNode)? body # allocateStmt
| FREE '(' varNode ')' # freeStmt
| REALIZE '(' IDENTIFIER ',' arrayNode ')' (IF primExprNode)? body # bufferRealizeStmt
| IF primExprNode body (ELSE body)? # ifThenElseStmt
| EVALUATE '(' primExprNode ')' # evaluateStmt
| FOR '(' varNode ',' primExprNode ',' primExprNode ')' STRINGIMM? body # forStmt
| PREFETCH '(' IDENTIFIER ',' arrayNode ')' # prefetchStmt
;
primExprNode
: CAST '(' DTYPE ',' primExprNode ')' # castExpr
| FLOORDIV '(' primExprNode ',' primExprNode ')' # floordivExpr
| FLOORMOD '(' primExprNode ',' primExprNode ')' # floormodExpr
| MIN '(' primExprNode ',' primExprNode ')' # minExpr
| MAX '(' primExprNode ',' primExprNode ')' # maxExpr
| SELECT '(' primExprNode ',' primExprNode ',' primExprNode ')' # selectExpr
| varNode arrayNode (IF primExprNode)? # loadExpr
| RAMP '(' primExprNode ',' primExprNode ',' INT ')' # rampExpr
| BROADCAST '(' primExprNode ',' INT ')' # broadcastExpr
| LET varNode '=' primExprNode IN primExprNode # letExpr
| CALL '(' DTYPE ',' STRINGIMM ',' arrayNode ',' STRINGIMM ',' INT ')' # callExpr
| SHUFFLE '(' arrayNode ',' arrayNode ')' # shuffleExpr
| REDUCE '(' commReducerNode ',' arrayNode ',' arrayNode ',' INT ')' (IF primExprNode)? # reduceExpr
| <assoc=right> op='!' primExprNode # notExpr
| src1 = primExprNode op=('*' | '/' | '%') src2 = primExprNode # binExpr
| src1 = primExprNode op=('+' | '-') src2 = primExprNode # binExpr
| src1 = primExprNode op=('<' | '>') src2 = primExprNode # binExpr
| src1 = primExprNode op=('<=' | '>=') src2 = primExprNode # binExpr
| src1 = primExprNode op=('==' | '!=') src2 = primExprNode # binExpr
| src1 = primExprNode op='&&' src2 = primExprNode # binExpr
| src1 = primExprNode op='||' src2 = primExprNode # binExpr
| varNode # varExpr
| meta # metaExpr
| immediate # immExpr
| '(' primExprNode ')' # parenExpr
;
varNode
: IDENTIFIER (':' typeExpr)?
;
commReducerNode
: COMMREDUCER '(' arrayNode ',' arrayNode ',' arrayNode ',' arrayNode ')'
;
bufferNode
: BUFFER '(' varNode ',' DTYPE ',' arrayNode ',' arrayNode ',' primExprNode ','
STRINGIMM ',' INT ',' INT ',' STRINGIMM ')'
;
rangeNode
: primExprNode ':' primExprNode
;
itervarNode
: ITERVAR '(' varNode ',' '[' rangeNode ']' ',' STRINGIMM ',' STRINGIMM ')'
;
immediate
: STRINGIMM # stringImm
| DTYPE '(' INT ')' # intImm
| DTYPE '(' FLOATIMM ')' # floatImm
| INT # int32Imm
| TRUELITERAL # trueLiteral
| FALSELITERAL # falseLiteral
;
meta
: 'meta' '[' CNAME ']' '[' INT ']'
;
// --- Type
typeExpr
: '(' ')' # tupleType
| DTYPE # primType
| 'Pointer' '(' typeExpr ')' # pointerType
;
METADATA: 'METADATA:' .*;
// --- Reserved words
MUL : '*' ;
DIV : '/' ;
ADD : '+' ;
SUB : '-' ;
MOD : '%' ;
LT : '<' ;
GT : '>' ;
LE : '<=' ;
GE : '>=' ;
EQ : '==' ;
NE : '!=' ;
AND : '&&' ;
OR : '||' ;
IN : 'in';
PRIMFUNC : 'PrimFunc';
IF : 'if';
ELSE : 'else';
TRUELITERAL : 'true';
FALSELITERAL : 'false';
LET : 'let';
ATTR : 'attr';
ASSERT : 'assert';
ALLOCATE : 'allocate';
FREE : 'free';
REALIZE : 'realize';
EVALUATE : 'evaluate';
FOR : 'for';
PREFETCH : 'prefetch';
CAST : 'cast';
FLOORDIV : 'floordiv';
FLOORMOD : 'floormod';
MIN : 'min';
MAX : 'max';
SELECT : 'select';
LOAD : 'load';
RAMP : 'ramp';
BROADCAST : 'broadcast';
CALL : 'call';
SHUFFLE : 'shuffle';
REDUCE : 'reduce';
BUFFER : 'Buffer';
COMMREDUCER : 'comm_reducer';
ITERVAR : 'IterVar';
fragment DIGIT
: [0-9]
;
fragment NAT
: DIGIT+
;
INT
: '-'? DIGIT+
;
fragment EXP
: [eE] [+\-]? NAT
;
FLOATIMM
: INT ('.' NAT)? EXP?
;
fragment LETTER
: [a-zA-Z]
;
STRINGIMM
: '"' ('\\n' | '\\\\' | '\\"' | .)*? '"'
;
DTYPE
: 'float' NAT ('x' NAT)?
| 'int' NAT ('x' NAT)?
| 'uint' NAT ('x' NAT)?
| 'bool'
| 'handle'
;
IDENTIFIER
    : '%' [a-zA-Z]+ [a-zA-Z_0-9]*
;
CNAME : ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;
WHITESPACE
: [ \t\n\r]+ -> skip
;
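For reference, a hedged driver sketch for the grammar above, assuming it is saved as tir.g4 and the Python target has been generated with antlr4 -Dlanguage=Python3 tir.g4 (module and class names follow ANTLR's generation conventions):

from antlr4 import InputStream, CommonTokenStream
from tirLexer import tirLexer    # generated from tir.g4
from tirParser import tirParser  # generated from tir.g4

def parse_module(src: str):
    lexer = tirLexer(InputStream(src))
    parser = tirParser(CommonTokenStream(lexer))
    return parser.module()  # 'module' is the start rule of the grammar

# A tiny function that exercises varNode, typeExpr, and a free statement.
tree = parse_module(
    'PrimFunc(%a: handle) -> () buffers={} buffer_map={} { free(%a) }'
)
print(tree.toStringTree())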
I tried to put var declarations ahead:

vars = {int32: [%stride_3, %stride_2, %stride_4, %stride_1, %n, %stride_5, %m, %stride, %j]}

I think we'd better use the first-encounter declaration style. To shorten the declaration of a buffer, using keyword params seems good:

buffers = {%A1: Buffer(%A1_2, float32, [%m : int32, %n : int32], [%stride_2 : int32, %stride_3 : int32]),
%C: Buffer(%C_2, float32, [%m, %n], [%stride_4 : int32, %stride_5 : int32]),
%A0: Buffer(%A0_2, float32, [%m, %n], [%stride : int32, %stride_1 : int32])}
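A sketch of the keyword-param shortening (the helper is hypothetical; the defaults 0, "global", 128, 1, "auto" are the trailing Buffer fields visible in the full dump earlier in this thread): print the core fields positionally, and emit the remaining fields only when they differ from their defaults.

BUFFER_DEFAULTS = [("elem_offset", 0), ("scope", '"global"'),
                   ("align", 128), ("offset_factor", 1), ("buffer_type", '"auto"')]

def print_buffer(name, data, dtype, shape, strides, **overrides):
    parts = [data, dtype, shape, strides]
    for key, default in BUFFER_DEFAULTS:
        value = overrides.get(key, default)
        if value != default:  # keyword form only for non-default fields
            parts.append("{}={}".format(key, value))
    return "{}: Buffer({})".format(name, ", ".join(str(p) for p in parts))

print(print_buffer("%A1", "%A1_2", "float32", "[%m, %n]", "[%stride_2, %stride_3]"))
# %A1: Buffer(%A1_2, float32, [%m, %n], [%stride_2, %stride_3])
print(print_buffer("%B", "%B_2", "float32", "[%n]", "[]", scope='"shared"'))
# %B: Buffer(%B_2, float32, [%n], [], scope="shared")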
How about something like that? The first-encounter style also looks OK. We could also introduce default types for indices (i32, and perhaps changing to i64 in the future).
Example for Conv on GPU:

primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
attr = {"tir.noalias": bool(1), "global_symbol": "main"}
buffers = {B: Buffer(B_2: handle, float32, [14, 14, 512, 256], []),
W: Buffer(W_2: handle, float32, [3, 3, 256, 512], []),
A: Buffer(A_2: handle, float32, [14, 14, 256, 256], [])}
buffer_map = {B_1: B, A_1: A, W_1: W} {
attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
attr [B.local: handle] "storage_scope" = "local";
allocate(B.local, float32, [64]) {
attr [Apad.shared: handle] "storage_scope" = "shared";
allocate(Apad.shared, float32, [512]) {
attr [W.shared: handle] "storage_scope" = "shared";
allocate(W.shared, float32, [512]) {
attr [Apad.shared.local: handle] "storage_scope" = "local";
allocate(Apad.shared.local, float32, [8]) {
attr [W.shared.local: handle] "storage_scope" = "local";
allocate(W.shared.local, float32, [8]) {
attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 8;
attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 4;
attr [IterVar(threadIdx.y: int32, [0:8], "ThreadIndex", "threadIdx.y")] "thread_extent" = 8;
attr [IterVar(threadIdx.x: int32, [0:8], "ThreadIndex", "threadIdx.x")] "thread_extent" = 8 {
for (ff.c.init: int32, 0, 4) {
for (nn.c.init: int32, 0, 4) {
B.local[((ff.c.init*4) + nn.c.init)] = float32(0)
B.local[(((ff.c.init*4) + nn.c.init) + 32)] = float32(0)
B.local[(((ff.c.init*4) + nn.c.init) + 16)] = float32(0)
B.local[(((ff.c.init*4) + nn.c.init) + 48)] = float32(0)
}
}
for (rc.outer: int32, 0, 32) {
for (ry: int32, 0, 3) {
for (rx: int32, 0, 3) {
for (ax3.inner.outer: int32, 0, 2) {
Apad.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer*4)), 1, 4)] = call(float32x4, "tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + ry)) and ((floordiv(blockIdx.z, 14) + ry) < 15)) and (1 <= (rx + floormod(blockIdx.z, 14)))) and ((rx + floormod(blockIdx.z, 14)) < 15)), load(float32x4, A_2[ramp((((((((((ry*917504) + (blockIdx.z*65536)) + (rx*65536)) + (rc.outer*2048)) + (threadIdx.y*256)) + (blockIdx.x*64)) + (threadIdx.x*8)) + (ax3.inner.outer*4)) - 983040), 1, 4)]), broadcast(float32(0), 4)], "pure_intrin", 0)
}
for (ax3.inner.outer_1: int32, 0, 2) {
W.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)] = load(float32x4, W_2[ramp((((((((ry*393216) + (rx*131072)) + (rc.outer*4096)) + (threadIdx.y*512)) + (blockIdx.y*64)) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)])
}
for (rc.inner: int32, 0, 8) {
for (ax3: int32, 0, 4) {
Apad.shared.local[ax3] = load(float32, Apad.shared[(((rc.inner*64) + (threadIdx.x*4)) + ax3)])
Apad.shared.local[(ax3 + 4)] = load(float32, Apad.shared[((((rc.inner*64) + (threadIdx.x*4)) + ax3) + 32)])
}
for (ax3_1: int32, 0, 4) {
W.shared.local[ax3_1] = load(float32, W.shared[(((rc.inner*64) + (threadIdx.y*4)) + ax3_1)])
W.shared.local[(ax3_1 + 4)] = load(float32, W.shared[((((rc.inner*64) + (threadIdx.y*4)) + ax3_1) + 32)])
}
for (ff.c: int32, 0, 4) {
for (nn.c: int32, 0, 4) {
B.local[((ff.c*4) + nn.c)] = (load(float32, B.local[((ff.c*4) + nn.c)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[ff.c])))
B.local[(((ff.c*4) + nn.c) + 32)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 32)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[(ff.c + 4)])))
B.local[(((ff.c*4) + nn.c) + 16)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 16)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[ff.c])))
B.local[(((ff.c*4) + nn.c) + 48)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 48)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[(ff.c + 4)])))
}
}
}
}
}
}
for (ff.inner.inner.inner: int32, 0, 4) {
for (nn.inner.inner.inner: int32, 0, 4) {
B_2[(((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)] = load(float32, B.local[((ff.inner.inner.inner*4) + nn.inner.inner.inner)])
B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8192)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 32)])
B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 32)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 16)])
B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8224)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 48)])
}
}
}
}
}
}
}
}
}

GEMM on CPU:

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
attr = {"tir.noalias": bool(1), "global_symbol": "main"}
buffers = {C: Buffer(C_2: handle, float32, [1024, 1024], []),
B: Buffer(B_2: handle, float32, [1024, 1024], []),
A: Buffer(A_2: handle, float32, [1024, 1024], [])}
buffer_map = {C_1: C, A_1: A, B_1: B} {
attr [packedB: handle] "storage_scope" = "global";
allocate(packedB, float32x32, [32768]) {
for (x: int32, 0, 32) "parallel" {
for (y: int32, 0, 1024) {
packedB[ramp(((x*32768) + (y*32)), 1, 32)] = load(float32x32, B_2[ramp(((y*1024) + (x*32)), 1, 32)])
}
}
for (x.outer: int32, 0, 32) "parallel" {
attr [C.global: handle] "storage_scope" = "global";
allocate(C.global, float32, [1024]) {
for (y.outer: int32, 0, 32) {
for (x.c.init: int32, 0, 32) {
C.global[ramp((x.c.init*32), 1, 32)] = broadcast(float32(0), 32)
}
for (k.outer: int32, 0, 256) {
for (x.c: int32, 0, 32) {
C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[(((x.outer*32768) + (x.c*1024)) + (k.outer*4))]), 32)*load(float32x32, packedB[ramp(((y.outer*32768) + (k.outer*128)), 1, 32)])))
C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 1)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 32), 1, 32)])))
C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 2)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 64), 1, 32)])))
C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 3)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 96), 1, 32)])))
}
}
for (x.inner: int32, 0, 32) {
for (y.inner: int32, 0, 32) {
C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = load(float32, C.global[((x.inner*32) + y.inner)])
}
}
}
}
}
}
}

TensorCore for Conv:

primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
attr = {"tir.noalias": bool(1), "global_symbol": "main"}
buffers = {W: Buffer(W_2: handle, float16, [3, 3, 16, 32, 16, 16], []),
A: Buffer(A_2: handle, float16, [16, 14, 14, 16, 16, 16], []),
Conv: Buffer(Conv_2: handle, float32, [16, 14, 14, 32, 16, 16], [])}
buffer_map = {A_1: A, Conv_1: Conv, W_1: W} {
attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
attr [Conv.wmma.accumulator: handle] "storage_scope" = "wmma.accumulator";
allocate(Conv.wmma.accumulator, float32, [2048]) {
attr [Apad.shared: handle] "storage_scope" = "shared";
allocate(Apad.shared, float16, [12288]) {
attr [W.shared: handle] "storage_scope" = "shared";
allocate(W.shared, float16, [12288]) {
attr [Apad.shared.wmma.matrix_a: handle] "storage_scope" = "wmma.matrix_a";
allocate(Apad.shared.wmma.matrix_a, float16, [512]) {
attr [W.shared.wmma.matrix_b: handle] "storage_scope" = "wmma.matrix_b";
allocate(W.shared.wmma.matrix_b, float16, [1024]) {
attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
attr [IterVar(threadIdx.y: int32, [(nullptr)], "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
attr [IterVar(threadIdx.z: int32, [(nullptr)], "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
for (n.c.init: int32, 0, 2) {
for (o.c.init: int32, 0, 4) {
eval(call("tvm_fill_fragment", [Conv.wmma.accumulator, 16, 16, 16, ((n.c.init*4) + o.c.init), float32(0)], handle, "intrin", 0))
}
}
for (ic.outer: int32, 0, 8) {
for (kh: int32, 0, 3) {
for (ax2: int32, 0, 3) {
for (ax3: int32, 0, 2) {
for (ax4.ax5.fused.outer: int32, 0, 8) {
attr [IterVar(threadIdx.x: int32, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = call("tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + kh)) and ((floordiv(blockIdx.z, 14) + kh) < 15)) and (1 <= (ax2 + floormod(blockIdx.z, 14)))) and ((ax2 + floormod(blockIdx.z, 14)) < 15)), load(float16, A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)]), float16(0)], float16, "pure_intrin", 0)
}
}
}
for (ax1: int32, 0, 3) {
for (ax2_1: int32, 0, 2) {
attr [IterVar(threadIdx.x, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = load(float16x8, W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)])
}
}
for (ic.inner: int32, 0, 2) {
for (kw: int32, 0, 3) {
for (ax0: int32, 0, 2) {
eval(call("tvm_load_matrix_sync", [Apad.shared.wmma.matrix_a, 16, 16, 16, ax0, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), Apad.shared, ((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
}
for (ax3_1: int32, 0, 4) {
eval(call("tvm_load_matrix_sync", [W.shared.wmma.matrix_b, 16, 16, 16, ax3_1, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), W.shared, ((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
}
for (n.c: int32, 0, 2) {
for (o.c: int32, 0, 4) {
eval(call("tvm_mma_sync", [Conv.wmma.accumulator, ((n.c*4) + o.c), Apad.shared.wmma.matrix_a, n.c, W.shared.wmma.matrix_b, o.c, Conv.wmma.accumulator, ((n.c*4) + o.c)], handle, "intrin", 0))
}
}
}
}
}
}
for (n.inner: int32, 0, 2) {
for (o.inner: int32, 0, 4) {
eval(call("tvm_store_matrix_sync", [Conv.wmma.accumulator, 16, 16, 16, ((n.inner*4) + o.inner), call("tvm_access_ptr", [call("type_annotation", [], float32, "pure_intrin", 0), Conv_2, (((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)), 256, 2], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
}
}
}
}
}
}
}
}
}
So far we have a text printer for Relay, which allows us to print an IRModule in text format. On the TIR side, we still rely on the ReprPrinter.

This issue is for upgrading the text printer so that we can print an IRModule that includes PrimFuncs (tir::Function in the upstream) in text format. This will help us enhance the demo experience.
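As a concrete target, a sketch of the intended user experience (TVM API names circa 2020; treat the exact signatures as assumptions):

import tvm
from tvm import te

n = te.var("n")
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
s = te.create_schedule(B.op)

mod = tvm.lower(s, [A, B])  # an IRModule containing a PrimFunc
print(mod)  # goal: a parseable text format for the PrimFunc, not just repr output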
Possible Design Points
Ideally we want to land a version in the mainline in about two to three weeks. @spectrometerHBH please see if it is possible for you and @Hzfengsy to coordinate on a format and land a version; then we can pull it back to the TensorIR branch.