
[TASK] TextPrinter Support for Unified TVM IRs #48

Closed
tqchen opened this issue Mar 25, 2020 · 28 comments

@tqchen
Contributor

tqchen commented Mar 25, 2020

So far we have a text printer for Relay, which allows us to print an IRModule into text format. On the TIR side, we still rely on the ReprPrinter.

This issue is for upgrading the text printer so that we can print an IRModule that includes PrimFuncs (tir::Function in the upstream) in text format. This will help us enhance the demo experience.

Possible Design Points

  • Try to be as consistent as possible with Relay (e.g. use %x for variables)
  • Think about parsing, but the first attempt does not have to be parsable
  • Think about metadata usage

Ideally we want to land a version in the mainline in about two to three weeks. @spectrometerHBH please see if it is possible for you and @Hzfengsy to coordinate on a format and land a version; then we can pull it back to TensorIR.
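For concreteness, the relay-style convention mentioned above (variables printed with a % prefix) can be sketched in a few lines of plain Python. This is a toy illustration only, not TVM's printer; the Var and Add classes below are invented for the example.

```python
# Toy sketch of the %-prefixed variable convention; not TVM's implementation.
class Var:
    def __init__(self, name_hint):
        self.name_hint = name_hint

class Add:
    def __init__(self, a, b):
        self.a, self.b = a, b

def print_expr(expr):
    """Print an expression with relay-style %-prefixed variables."""
    if isinstance(expr, Var):
        return "%" + expr.name_hint
    if isinstance(expr, Add):
        return "(" + print_expr(expr.a) + " + " + print_expr(expr.b) + ")"
    raise TypeError("unsupported node: %r" % type(expr))

print(print_expr(Add(Var("x"), Var("y"))))  # prints (%x + %y)
```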

@tqchen tqchen changed the title [TASK] TextPrinter Support for All TVM IRs [TASK] TextPrinter Support for Unified TVM IRs Mar 25, 2020
@spectrometerHBH
Collaborator

So ideally we want the text format & hybrid script to be two different printing methods for TIR, both of which are parseable?

@tqchen
Contributor Author

tqchen commented Mar 27, 2020

Yes, but as a path for progression, we can:

  • Have a text printer that is not parsable
  • Support hybrid bidirection
  • Support text parser

@spectrometerHBH
Collaborator

spectrometerHBH commented Mar 27, 2020

What is hybrid bidirection? Bidirectional translation between text format & hybrid?

@tqchen
Contributor Author

tqchen commented Mar 27, 2020

Sorry, I meant hybrid parsing and printing.

@Hzfengsy
Member

Hzfengsy commented Apr 7, 2020

Generally, it is OK to use the current ReprPrinter as long as we make some minor modifications. However, there are some details to be determined.

Function Syntax

  1. As we know, a relay function has its arg and return types.
    fn(%x : Tensor[(10, 10), float32], %y : Tensor[(10, 10), float32])
               -> Tensor[(10, 10), float32] {
        add(%x, %y)
    }
    
    Of course, TIR can do the same thing to show the parameter types. The problem is that a TIR function has no return type, and not even a return value: it directly changes a buffer by mutating its elements. Here are two options:
    • Just make it without a return value. Pros: more natural for low-level code. Cons: the syntax is not so unified, and the vars are no longer immutable in the "unified" syntax.
    • Change the behavior of TIR functions so that they have a return value and type. Pros: unified, and it makes the vars immutable. Cons: it is strange for low-level code to return a buffer. A TIR function may also have more than one return buffer, or may even change an input buffer.

I think neither is a perfect solution, but I have no other ideas. I would love to hear your opinions.

Annotation

I'd like to add a kind of annotation to functions to show whether each is a Relay or TIR function. One possible syntax would be

@relay
fn(%x, %y) { add(%x, %y) }
@tir
fn(%x, %y) { X[0] = Y[0] }

Function Body

There is no doubt that Relay functions and TIR functions have different stmts and exprs (e.g. let bindings in Relay and BufferLoad in TIR). So I think we can ignore the differences in the function body but try to make the function signature itself unified.

@spectrometerHBH @tqchen

@tqchen
Contributor Author

tqchen commented Apr 7, 2020

  • re the signature: we do have signatures for TIR functions now; see the implementation of the function signature here: https://github.com/apache/incubator-tvm/blob/master/include/tvm/tir/function.h#L129
    • The default return type of a PrimFunc is void, which is represented by an empty tuple type (as in Swift)
    • After the function is lowered to a PackedFunc, the return type could become int; both are indicated by the ret_type
  • re annotation: for now, let us try to annotate the PrimFunc and keep Relay as the default, un-annotated (to preserve backward compatibility). Besides the top-level annotation, perhaps a modifier keyword is another approach; think of inline void func, extern "c" int func. Perhaps we could add a keyword (e.g. primfunc fn) to indicate a primfunc.

@tqchen
Contributor Author

tqchen commented Apr 16, 2020

FYI, I am in the process of refactoring out the tensor field (func) in Realize/Provide and replacing them with BufferRealize, BufferStore and BufferLoad. So we don't need to consider supporting these nodes in the text format.

@tqchen
Contributor Author

tqchen commented Apr 20, 2020

Related change in the upstream apache/tvm#5372

@spectrometerHBH
Collaborator

A comparison between the repr format and the text format (which is problematic now):

[23:41:44] /home/spectrometer/tvm-upstream/src/te/schedule/bound.cc:128: not in feed graph consumer = compute(B, 0x1fb5cf0)
PrimFunc([A0, A1, C]) attrs={"tir.noalias": (bool)1, "global_symbol": "main"} {
  // attr [B.v0] storage_scope = "global"
  allocate B.v0[float32 * n]
  // attr [B.v1] storage_scope = "global"
  allocate B.v1[float32 * n]
  for (i, 0, m) {
    for (j, 0, n) {
      B.v0[j] = (A0[((i*stride) + (j*stride))] + 2f)
      B.v1[j] = (A0[((i*stride) + (j*stride))]*3f)
    }
    for (j, 0, n) {
      C[((i*stride) + (j*stride))] = (A1[((i*stride) + (j*stride))] + B.v0[j])
    }
  }
}


[23:41:44] /home/spectrometer/tvm-upstream/src/te/schedule/bound.cc:128: not in feed graph consumer = compute(B, 0x1fb5cf0)
class Module:
  PrimFunc(%A0: Buffer([%m, %n], "float32", name=%A01),
           %A1: Buffer([%m, %n], "float32", name=%A11),
           %C: Buffer([%m, %n], "float32", name=%C1)) {
    attr [%B.v0] "storage_scope" = "global" {
      allocate(%B.v0, "float32", [%n]) if "bool"(1) {
        attr [%B.v1] "storage_scope" = "global" {
          allocate(%B.v1, "float32", [%n]) if "bool"(1) {
            for (%i, 0, %m) "serial" {
              for (%j, 0, %n) "serial" {
                %B.v0[%j] = (%A0[((%i*%stride) + (%j*%stride1))] + "float32"(2)) if "bool"(1)
                %B.v1[%j] = (%A0[((%i*%stride) + (%j*%stride1))]*"float32"(3)) if "bool"(1)
              }
              for (%j1, 0, %n) "serial" {
                %C[((%i*%stride2) + (%j1*%stride3))] = (%A1[((%i*%stride4) + (%j1*%stride5))] + %B.v0[%j1]) if "bool"(1)
              }
            }
          }
        }
      }
    }
  }

Several Problems

1. The name of Vars

For readability, the name_hint of a var is useful. But B.v0 may not be a legal identifier for the parser, since it contains a '.'. And the user may give an arbitrary name.

2. dtype of Vars

I'm not very clear on where Var is used in the upstream AST. As far as I can see now, it is:

  1. used in buffer shape declarations
  2. loop vars
  3. stride vars (stride above, but I'm not clear what they are)
  4. buffer vars
  5. iter vars (but iter vars seem to only appear in attr; otherwise we use the var inside the iter_var)

When printing a var, the simplest way is to just give its name_hint, but in principle we have to determine its dtype as well. In most cases, though, the dtype doesn't matter. I'm not sure about this point.

@tqchen
Contributor Author

tqchen commented Apr 23, 2020

    1. Given that the name_hint is only a hint, we do not have to abide by it; we can rewrite . to _ or other symbols
    2. dtype should be printed at declaration (e.g. the type annotation of let, or of for, where the var is declared) but not at the usage point
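The rewriting rule above (map '.' and other illegal characters to '_', keeping the resulting names unique) can be sketched in plain Python. This is an illustration of the idea only, not the actual printer code; it mirrors how B.v0 becomes B_v0 and how repeated hints get _1, _2 suffixes in the dumps later in this thread.

```python
import re

def sanitize(name_hint, used):
    """Rewrite illegal identifier characters (e.g. '.') to '_',
    then append _1, _2, ... until the name is unique."""
    base = re.sub(r"[^0-9a-zA-Z_]", "_", name_hint)
    name, i = base, 0
    while name in used:
        i += 1
        name = "%s_%d" % (base, i)
    used.add(name)
    return name

used = set()
print(sanitize("B.v0", used))    # prints B_v0
print(sanitize("stride", used))  # prints stride
print(sanitize("stride", used))  # prints stride_1
```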

@tqchen
Contributor Author

tqchen commented Apr 23, 2020

Some additional thoughts after seeing the example.

  • When the dtype needs to be printed, we can always print the Type of the var, not the dtype. Since Type always implies dtype, but not the other way around, we can use https://github.com/apache/incubator-tvm/blob/master/include/tvm/tir/op.h#L55 to obtain the type.
    See the discussion of dtype vs Type here: https://github.com/apache/incubator-tvm/blob/master/include/tvm/ir/type.h#L30
    • We can make the type printing consistent with relay
  • The for_type "serial" does not need to be printed (because it is the default); we might want to discuss how to print the others
  • Constant literals: "float32"(1) could be printed as 1.0f, or as float32(1)
  • When the predicate is const_true (in allocate), do not print the predicate
  • The % prefix seems quite dense and confusing when used together with operators (e.g. %); perhaps we want to allow variable names without the prefix?
  • Perhaps we should print attr differently (it is a bit strange to have a string literal appear on the rhs)

A strawman; I am not too happy with the syntax:

attr ("storage_scope", %Bv0) = "global"

@spectrometerHBH
Collaborator

spectrometerHBH commented Apr 24, 2020

For vars like %m and %n, should we print their Type/dtype? Their declarations are outside the AST.

%A0: Buffer([%m, %n], "float32", name=%A0_1),
%A1: Buffer([%m, %n], "float32", name=%A1_1),
%C: Buffer([%m, %n], "float32", name=%C_1)

One way is to print the Type the first time we encounter the Var:

%A0: Buffer([%m : int32, %n : int32], "float32", name=%A0_1),
%A1: Buffer([%m, %n], "float32", name=%A1_1),
%C: Buffer([%m, %n], "float32", name=%C_1)

Another way is to declare their defs somewhere else first.

The third way is to put them in meta.

I think all of these ways are somewhat strange.

And for vars like stride:

%A0[((%i*%stride) + (%j*%stride1))

They look more like placeholders for an unknown buffer shape. Does their Type/dtype matter?
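The "print the Type on first encounter" option can be sketched as a small stateful printer. The Var class and names below are invented for illustration; TVM's actual printer tracks this differently.

```python
# Toy sketch of first-encounter type annotation; not TVM's printer.
class Var:
    def __init__(self, name, dtype):
        self.name, self.dtype = name, dtype

def make_var_printer():
    seen = set()
    def print_var(v):
        # Annotate the type only the first time a var is printed.
        if v.name in seen:
            return "%" + v.name
        seen.add(v.name)
        return "%%%s: %s" % (v.name, v.dtype)
    return print_var

p = make_var_printer()
m = Var("m", "int32")
print(p(m))  # prints %m: int32
print(p(m))  # prints %m
```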

@tqchen
Contributor Author

tqchen commented Apr 24, 2020

I think first-encounter, or declaring the defs somewhere else, makes sense. Formally, we should not put buffer_map in the argument list (that makes Buffer look like a type, but Buffers are not types); we might want to have some section like C++'s initializer list (the corresponding hybrid script part is the buffer bind).

@spectrometerHBH
Collaborator

It seems that relay's text printer doesn't support printing PrimType and PointerType. Will this be fixed in the upstream?

@tqchen
Contributor Author

tqchen commented Apr 24, 2020

I think it will be up to us to fix them :)

@spectrometerHBH
Collaborator

spectrometerHBH commented Apr 25, 2020

Discussion on printing Buffer/Var, and META usage & recovery

An unavoidable problem in the text format is how to print the buffer_map (Map<tir::Var, Buffer>). I will expand my discussion based on this problem. cc @tqchen @Hzfengsy

PrimFunc(%A0: handle, %A1: handle, %C: handle)
  %C: Buffer(%C_1, [%m, %n], float32, [%stride, %stride_1], %C_2)
  %A0: Buffer(%A0_1, [%m, %n], float32, [%stride_2, %stride_3], %A0_2)
  %A1: Buffer(%A1_1, [%m, %n], float32, [%stride_4, %stride_5], %A1_2) {
  attr [%B_v0] "storage_scope" = "global" {
    allocate(%B_v0, float32, [%n])  {
      attr [%B_v1] "storage_scope" = "global" {
        allocate(%B_v1, float32, [%n])  {
          for (%i, 0, %m)  {
            for (%j, 0, %n)  {
              %B_v0[%j] = (%A0_1[((%i*%stride_2) + (%j*%stride_3))] + float32(2))
              %B_v1[%j] = (%A0_1[((%i*%stride_2) + (%j*%stride_3))]*float32(3))
            }
            for (%j_1, 0, %n)  {
              %C_1[((%i*%stride) + (%j_1*%stride_1))] = (%A1_1[((%i*%stride_4) + (%j_1*%stride_5))] + %B_v0[%j_1])
            }
          }
        }
      }
    }
  }
}  

C0. Print all Buffer information out

Pros:

If we avoid using META in printing, then the text format is fully writable and modifiable by the user, as long as the user follows the syntax rules.

Cons:

It's really long to print all the Buffer info, which hurts readability.

C1. Put Buffer in META

Pros:

Clearer buffer_map printing

Cons:

  1. The user cannot write the text format freely, since we use META.
  2. Vars like %m, %n, and %stride are put into META along with the Buffer; for their usages below, we have to print meta[Var][0], which makes the usage hard to read.

I further propose several ways to deal with the cons of C1.

C1.1 Scan the Vars in Buffer in advance and store a name map for them into META

When printing, we give each Var we encounter a unique name_hint. To be able to print stride instead of meta[Var][0], we can store a Map<std::string, Var> into META for all the vars that appear in Buffer declarations.
When parsing, we look the name up in the name map when we encounter stride, to retrieve the right Var from META.
Cons: hurts the writability of the text format, since the buffer declarations lie in META.

C1.2 Scan all the Vars in AST and store a name map for them

Consider such a scenario in the general case:
A Var is declared first, but one of its usages is going to be put into META along with some other non-printable node containing it.
To print correctly, we have to collect all vars that are going to be put into META; C1.1 is a special case that handles the Vars in Buffers. If in the future we want to put more nodes into META, we have to keep this problem in mind.

But we can do this more aggressively to eliminate the problem entirely: just scan all the Vars in the AST and store a name map for them in META. Then for all Var occurrences, we can just print a name_hint, no matter how the Var is used.

Pros: simple and works generally. Note that it also works if we don't put the Buffer into META.
Cons: hurts writability the most, since all the declarations of Vars lie in META.
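C1.2's pre-scan can be sketched as follows: one pass assigns each distinct Var object a unique printable name, producing the name map that would go into META. Illustrative Python only; a real AST walker would cover all node kinds, not just lists of Vars.

```python
# Toy sketch of C1.2's var pre-scan; not TVM's implementation.
class Var:
    def __init__(self, name_hint):
        self.name_hint = name_hint

def collect_vars(node, table, used):
    """Assign each distinct Var object a unique printable name."""
    if isinstance(node, Var):
        if id(node) not in table:
            name, i = node.name_hint, 0
            while name in used:
                i += 1
                name = "%s_%d" % (node.name_hint, i)
            used.add(name)
            table[id(node)] = name
    elif isinstance(node, (list, tuple)):
        for child in node:
            collect_vars(child, table, used)

a, b = Var("x"), Var("x")  # two distinct Vars sharing a name_hint
table, used = {}, set()
collect_vars([a, b, a], table, used)
print(table[id(a)], table[id(b)])  # prints x x_1
```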

@spectrometerHBH
Collaborator

I will land a version that avoids using META, then.

@spectrometerHBH
Collaborator

Does GetType always return a PrimType or PointerType? Is there an API to convert a Type to a dtype?

@tqchen
Contributor Author

tqchen commented Apr 27, 2020

Type to dtype: https://github.com/apache/incubator-tvm/blob/master/include/tvm/tir/op.h#L70 (note that it is always safe to construct a Var with a Type, because Type is more fine-grained). It is fine to print a Var without a Type (only with a dtype) but load it back containing a Type.

We cannot always deduce the Type from the dtype, though.

@spectrometerHBH
Collaborator

btw, if an identifier doesn't begin with %, the lexer will confuse it with the CNAME used in meta. So I will keep % for now.

@tqchen
Contributor Author

tqchen commented Apr 27, 2020

I see; we will then need to at least print mod as truncmod instead of %.
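A minimal sketch of that rule: since % is reserved for identifiers, the mod operator falls back to a functional spelling, while other binary ops keep their infix form. Toy code for illustration, not the actual printer.

```python
# Toy sketch: print mod in functional form to keep % free for identifiers.
def print_binop(op, lhs, rhs):
    if op == "%":
        return "truncmod(%s, %s)" % (lhs, rhs)
    return "(%s %s %s)" % (lhs, op, rhs)

print(print_binop("%", "%i", "8"))  # prints truncmod(%i, 8)
print(print_binop("+", "%i", "8"))  # prints (%i + 8)
```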

@spectrometerHBH
Collaborator

spectrometerHBH commented Apr 28, 2020

I've implemented a printer & parser for my current syntax. I'll attach the g4 file below for discussion in our next meeting. Note that it doesn't support META in Attr yet, and Attr (and DictAttrs) is the only place that will result in nodes being put into META.

Several Points:

  1. load needs a dtype as a parameter of its constructor, and I haven't decided on the syntax.
  2. I've tested several small test cases, but I have no clear idea of how to test it thoroughly. The only way I can see now is to test it using the ir_builder. Maybe I can also try the operators provided in topi.
  3. buffers stores the complete info of all the buffers that appear in the IR. But we could use more ideas on how to print it.

@spectrometerHBH
Collaborator

spectrometerHBH commented Apr 28, 2020

PrimFunc(%A0_1: handle, %A1_1: handle, %C_1: handle) -> ()
  buffers={%A1: Buffer(%A1_2: handle, float32, [%m: int32, %n: int32], [%stride: int32, %stride_1: int32], 0, "global", 128, 1, "auto"),
           %C: Buffer(%C_2: handle, float32, [%m, %n], [%stride_2: int32, %stride_3: int32], 0, "global", 128, 1, "auto"),
           %A0: Buffer(%A0_2: handle, float32, [%m, %n], [%stride_4: int32, %stride_5: int32], 0, "global", 128, 1, "auto")}
  buffer_map={%C_1: %C, %A0_1: %A0, %A1_1: %A1} {
  attr [%B_v0: handle] "storage_scope" = "global" {
    allocate(%B_v0, float32, [%n])  {
      attr [%B_v1: handle] "storage_scope" = "global" {
        allocate(%B_v1, float32, [%n])  {
          for (%i: int32, 0, %m)  {
            for (%j: int32, 0, %n)  {
              %B_v0[%j] = (%A0_2[((%i*%stride_4) + (%j*%stride_5))] + float32(2))
              %B_v1[%j] = (%A0_2[((%i*%stride_4) + (%j*%stride_5))]*float32(3))
            }
            for (%j_1: int32, 0, %n)  {
              %C_2[((%i*%stride_2) + (%j_1*%stride_3))] = (%A1_2[((%i*%stride) + (%j_1*%stride_1))] + %B_v0[%j_1])
            }
          }
        }
      }
    }
  }
}
grammar tir;

module
    : function* METADATA?
    ;

function
    : PRIMFUNC '(' varNode (',' varNode)* ')' '->' typeExpr
      'buffers' '=' '{' buffer_list?  '}'
      'buffer_map' '=' '{' buffer_map_list? '}'
       body
    ;

buffer_list
    : IDENTIFIER ':' bufferNode (',' IDENTIFIER ':' bufferNode)*
    ;

buffer_map_list
    : IDENTIFIER ':' IDENTIFIER (',' IDENTIFIER ':' IDENTIFIER)*
    ;

body
    : '{' stmtNode* '}'
    ;

node
    : primExprNode
    | stmtNode
    | rangeNode
    | commReducerNode
    | arrayNode
    | itervarNode
    | bufferNode
    | meta
    ;

arrayNode
    : '[' node (',' node)* ']'
    ;

stmtNode
    : LET varNode '=' primExprNode body # letStmt
    | ATTR '[' (varNode|itervarNode) ']' STRINGIMM '=' primExprNode body # attrStmt
    | ASSERT '(' primExprNode ',' primExprNode ')' body # assertStmt
    | varNode arrayNode '=' primExprNode (IF primExprNode)? # storeStmt
    | ALLOCATE '(' varNode ',' DTYPE ',' arrayNode ')' (IF primExprNode)? body # allocateStmt
    | FREE '(' varNode ')' # freeStmt
    | REALIZE '(' IDENTIFIER ',' arrayNode ')' (IF primExprNode)? body # bufferRealizeStmt
    | IF primExprNode body (ELSE body)? # ifThenElseStmt
    | EVALUATE '(' primExprNode ')' # evaluateStmt
    | FOR '(' varNode ',' primExprNode ',' primExprNode ')' STRINGIMM? body # forStmt
    | PREFETCH '(' IDENTIFIER ',' arrayNode ')' # prefetchStmt
    ;

primExprNode
    : CAST '(' DTYPE ',' primExprNode ')' # castExpr
    | FLOORDIV '(' primExprNode ',' primExprNode ')' # floordivExpr
    | FLOORMOD '(' primExprNode ',' primExprNode ')' # floormodExpr
    | MIN '(' primExprNode ',' primExprNode ')' # minExpr
    | MAX '(' primExprNode ',' primExprNode ')' # maxExpr
    | SELECT '(' primExprNode ',' primExprNode ',' primExprNode ')' # selectExpr
    | varNode arrayNode (IF primExprNode)? # loadExpr
    | RAMP '(' primExprNode ',' primExprNode ',' INT ')' # rampExpr
    | BROADCAST '(' primExprNode ',' INT ')' # broadcastExpr
    | LET varNode '=' primExprNode IN primExprNode # letExpr
    | CALL '(' DTYPE ',' STRINGIMM ',' arrayNode ',' STRINGIMM ',' INT ')' # callExpr
    | SHUFFLE '(' arrayNode ',' arrayNode ')' # shuffleExpr
    | REDUCE '(' commReducerNode ',' arrayNode ',' arrayNode ',' INT ')' (IF primExprNode)? # reduceExpr
    | <assoc=right> op='!' primExprNode # notExpr
    | src1 = primExprNode op=('*' | '/' | '%') src2 = primExprNode # binExpr
    | src1 = primExprNode op=('+' | '-') src2 = primExprNode # binExpr
    | src1 = primExprNode op=('<' | '>') src2 = primExprNode # binExpr
    | src1 = primExprNode op=('<=' | '>=') src2 = primExprNode # binExpr
    | src1 = primExprNode op=('==' | '!=') src2 = primExprNode # binExpr
    | src1 = primExprNode op='&&' src2 = primExprNode # binExpr
    | src1 = primExprNode op='||' src2 = primExprNode # binExpr
    | varNode # varExpr
    | meta # metaExpr
    | immediate # immExpr
    | '(' primExprNode ')' # parenExpr
    ;

varNode
    : IDENTIFIER (':' typeExpr)?
    ;

commReducerNode
    : COMMREDUCER '(' arrayNode ',' arrayNode ',' arrayNode ',' arrayNode ')'
    ;

bufferNode
    : BUFFER '(' varNode ',' DTYPE ',' arrayNode ',' arrayNode ',' primExprNode ','
      STRINGIMM ',' INT ',' INT ',' STRINGIMM ')'
    ;

rangeNode
    : primExprNode ':' primExprNode
    ;

itervarNode
    : ITERVAR '(' varNode ',' '[' rangeNode ']' ',' STRINGIMM ',' STRINGIMM ')'
    ;

immediate
    : STRINGIMM # stringImm
    | DTYPE '(' INT ')' # intImm
    | DTYPE '(' FLOATIMM ')' # floatImm
    | INT # int32Imm
    | TRUELITERAL # trueLiteral
    | FALSELITERAL # falseLiteral
    ;

meta
    : 'meta' '[' CNAME ']' '[' INT ']'
    ;

// --- Type
typeExpr
  : '(' ')'                                                                # tupleType
  | DTYPE                                                                  # primType
  | 'Pointer' '(' typeExpr ')'                                             # pointerType
  ;

METADATA: 'METADATA:' .*;

// --- Reserved words

MUL : '*' ;
DIV : '/' ;
ADD : '+' ;
SUB : '-' ;
MOD : '%' ;
LT : '<' ;
GT : '>' ;
LE : '<=' ;
GE : '>=' ;
EQ : '==' ;
NE : '!=' ;
AND : '&&' ;
OR : '||' ;

IN : 'in';
PRIMFUNC : 'PrimFunc';
IF : 'if';
ELSE : 'else';
TRUELITERAL : 'true';
FALSELITERAL : 'false';
LET : 'let';
ATTR : 'attr';
ASSERT : 'assert';
ALLOCATE : 'allocate';
FREE : 'free';
REALIZE : 'realize';
EVALUATE : 'evaluate';
FOR : 'for';
PREFETCH : 'prefetch';
CAST : 'cast';
FLOORDIV : 'floordiv';
FLOORMOD : 'floormod';
MIN : 'min';
MAX : 'max';
SELECT : 'select';
LOAD : 'load';
RAMP : 'ramp';
BROADCAST : 'broadcast';
CALL : 'call';
SHUFFLE : 'shuffle';
REDUCE : 'reduce';
BUFFER : 'Buffer';
COMMREDUCER : 'comm_reducer';
ITERVAR : 'IterVar';

fragment DIGIT
    :   [0-9]
    ;

fragment NAT
    :   DIGIT+
    ;

INT
    :   '-'? DIGIT+
    ;

fragment EXP
    :   [eE] [+\-]? NAT
    ;

FLOATIMM
    :   INT ('.' NAT)? EXP?
    ;

fragment LETTER
    :   [a-zA-Z]
    ;

STRINGIMM
    :   '"' ('\\n' | '\\\\' | '\\"' | .)*? '"'
    ;

DTYPE
    : 'float' NAT ('x' NAT)?
    | 'int' NAT ('x' NAT)?
    | 'uint' NAT ('x' NAT)?
    | 'bool'
    | 'handle'
    ;

IDENTIFIER
    :   '%' [a-zA-Z] + [a-zA-Z_0-9]*
    ;

CNAME : ('_'|LETTER) ('_'|LETTER|DIGIT)* ('.' CNAME)* ;

WHITESPACE
    :   [ \t\n\r]+ -> skip
    ;

@spectrometerHBH
Collaborator

spectrometerHBH commented Apr 29, 2020

I tried to put the var declarations ahead of the buffers, but I don't think it looks very good.

vars = {int32: [%stride_3, %stride_2, %stride_4, %stride_1, %n, %stride_5, %m, %stride, %j]}

I think we'd better use the first-encounter declaration style.

To shorten the declaration of a buffer, using keyword parameters seems good:

  buffers = {%A1: Buffer(%A1_2, float32, [%m : int32, %n : int32], [%stride_2 : int32, %stride_3 : int32]),
           %C: Buffer(%C_2, float32, [%m, %n], [%stride_4 : int32, %stride_5 : int32]),
           %A0: Buffer(%A0_2, float32, [%m, %n], [%stride : int32, %stride_1 : int32])}

@tqchen
Contributor Author

tqchen commented Apr 29, 2020

How about something like vars = [%stride_3: int32, %stride_0: int32]?

The first-encounter style also looks OK. We could also introduce default types for indices (i32, and perhaps change to i64 in the future).

@spectrometerHBH
Collaborator

spectrometerHBH commented Apr 30, 2020

Example for Conv on GPU

primfn(A_1: handle, W_1: handle, B_1: handle) -> ()
  attr = {"tir.noalias": bool(1), "global_symbol": "main"}
  buffers = {B: Buffer(B_2: handle, float32, [14, 14, 512, 256], []),
           W: Buffer(W_2: handle, float32, [3, 3, 256, 512], []),
           A: Buffer(A_2: handle, float32, [14, 14, 256, 256], [])}
  buffer_map = {B_1: B, A_1: A, W_1: W} {
  attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
  attr [B.local: handle] "storage_scope" = "local";
  allocate(B.local, float32, [64])  {
    attr [Apad.shared: handle] "storage_scope" = "shared";
    allocate(Apad.shared, float32, [512])  {
      attr [W.shared: handle] "storage_scope" = "shared";
      allocate(W.shared, float32, [512])  {
        attr [Apad.shared.local: handle] "storage_scope" = "local";
        allocate(Apad.shared.local, float32, [8])  {
          attr [W.shared.local: handle] "storage_scope" = "local";
          allocate(W.shared.local, float32, [8])  {
            attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 8;
            attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 4;
            attr [IterVar(threadIdx.y: int32, [0:8], "ThreadIndex", "threadIdx.y")] "thread_extent" = 8;
            attr [IterVar(threadIdx.x: int32, [0:8], "ThreadIndex", "threadIdx.x")] "thread_extent" = 8 {
              for (ff.c.init: int32, 0, 4) {
                for (nn.c.init: int32, 0, 4) {
                  B.local[((ff.c.init*4) + nn.c.init)] = float32(0)
                  B.local[(((ff.c.init*4) + nn.c.init) + 32)] = float32(0)
                  B.local[(((ff.c.init*4) + nn.c.init) + 16)] = float32(0)
                  B.local[(((ff.c.init*4) + nn.c.init) + 48)] = float32(0)
                }
              }
              for (rc.outer: int32, 0, 32) {
                for (ry: int32, 0, 3) {
                  for (rx: int32, 0, 3) {
                    for (ax3.inner.outer: int32, 0, 2) {
                      Apad.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer*4)), 1, 4)] = call(float32x4, "tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + ry)) and ((floordiv(blockIdx.z, 14) + ry) < 15)) and (1 <= (rx + floormod(blockIdx.z, 14)))) and ((rx + floormod(blockIdx.z, 14)) < 15)), load(float32x4, A_2[ramp((((((((((ry*917504) + (blockIdx.z*65536)) + (rx*65536)) + (rc.outer*2048)) + (threadIdx.y*256)) + (blockIdx.x*64)) + (threadIdx.x*8)) + (ax3.inner.outer*4)) - 983040), 1, 4)]), broadcast(float32(0), 4)], "pure_intrin", 0)
                    }
                    for (ax3.inner.outer_1: int32, 0, 2) {
                      W.shared[ramp((((threadIdx.y*64) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)] = load(float32x4, W_2[ramp((((((((ry*393216) + (rx*131072)) + (rc.outer*4096)) + (threadIdx.y*512)) + (blockIdx.y*64)) + (threadIdx.x*8)) + (ax3.inner.outer_1*4)), 1, 4)])
                    }
                    for (rc.inner: int32, 0, 8) {
                      for (ax3: int32, 0, 4) {
                        Apad.shared.local[ax3] = load(float32, Apad.shared[(((rc.inner*64) + (threadIdx.x*4)) + ax3)])
                        Apad.shared.local[(ax3 + 4)] = load(float32, Apad.shared[((((rc.inner*64) + (threadIdx.x*4)) + ax3) + 32)])
                      }
                      for (ax3_1: int32, 0, 4) {
                        W.shared.local[ax3_1] = load(float32, W.shared[(((rc.inner*64) + (threadIdx.y*4)) + ax3_1)])
                        W.shared.local[(ax3_1 + 4)] = load(float32, W.shared[((((rc.inner*64) + (threadIdx.y*4)) + ax3_1) + 32)])
                      }
                      for (ff.c: int32, 0, 4) {
                        for (nn.c: int32, 0, 4) {
                          B.local[((ff.c*4) + nn.c)] = (load(float32, B.local[((ff.c*4) + nn.c)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[ff.c])))
                          B.local[(((ff.c*4) + nn.c) + 32)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 32)]) + (load(float32, Apad.shared.local[nn.c])*load(float32, W.shared.local[(ff.c + 4)])))
                          B.local[(((ff.c*4) + nn.c) + 16)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 16)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[ff.c])))
                          B.local[(((ff.c*4) + nn.c) + 48)] = (load(float32, B.local[(((ff.c*4) + nn.c) + 48)]) + (load(float32, Apad.shared.local[(nn.c + 4)])*load(float32, W.shared.local[(ff.c + 4)])))
                        }
                      }
                    }
                  }
                }
              }
              for (ff.inner.inner.inner: int32, 0, 4) {
                for (nn.inner.inner.inner: int32, 0, 4) {
                  B_2[(((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner)] = load(float32, B.local[((ff.inner.inner.inner*4) + nn.inner.inner.inner)])
                  B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8192)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 32)])
                  B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 32)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 16)])
                  B_2[((((((((blockIdx.z*131072) + (blockIdx.y*16384)) + (threadIdx.y*1024)) + (ff.inner.inner.inner*256)) + (blockIdx.x*64)) + (threadIdx.x*4)) + nn.inner.inner.inner) + 8224)] = load(float32, B.local[(((ff.inner.inner.inner*4) + nn.inner.inner.inner) + 48)])
                }
              }
            }
          }
        }
      }
    }
  }
}

@spectrometerHBH
Collaborator

GEMM on CPU

primfn(A_1: handle, B_1: handle, C_1: handle) -> ()
  attr = {"tir.noalias": bool(1), "global_symbol": "main"}
  buffers = {C: Buffer(C_2: handle, float32, [1024, 1024], []),
           B: Buffer(B_2: handle, float32, [1024, 1024], []),
           A: Buffer(A_2: handle, float32, [1024, 1024], [])}
  buffer_map = {C_1: C, A_1: A, B_1: B} {
  attr [packedB: handle] "storage_scope" = "global";
  allocate(packedB, float32x32, [32768])  {
    for (x: int32, 0, 32) "parallel" {
      for (y: int32, 0, 1024) {
        packedB[ramp(((x*32768) + (y*32)), 1, 32)] = load(float32x32, B_2[ramp(((y*1024) + (x*32)), 1, 32)])
      }
    }
    for (x.outer: int32, 0, 32) "parallel" {
      attr [C.global: handle] "storage_scope" = "global";
      allocate(C.global, float32, [1024])  {
        for (y.outer: int32, 0, 32) {
          for (x.c.init: int32, 0, 32) {
            C.global[ramp((x.c.init*32), 1, 32)] = broadcast(float32(0), 32)
          }
          for (k.outer: int32, 0, 256) {
            for (x.c: int32, 0, 32) {
              C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[(((x.outer*32768) + (x.c*1024)) + (k.outer*4))]), 32)*load(float32x32, packedB[ramp(((y.outer*32768) + (k.outer*128)), 1, 32)])))
              C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 1)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 32), 1, 32)])))
              C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 2)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 64), 1, 32)])))
              C.global[ramp((x.c*32), 1, 32)] = (load(float32x32, C.global[ramp((x.c*32), 1, 32)]) + (broadcast(load(float32, A_2[((((x.outer*32768) + (x.c*1024)) + (k.outer*4)) + 3)]), 32)*load(float32x32, packedB[ramp((((y.outer*32768) + (k.outer*128)) + 96), 1, 32)])))
            }
          }
          for (x.inner: int32, 0, 32) {
            for (y.inner: int32, 0, 32) {
              C_2[((((x.outer*32768) + (x.inner*1024)) + (y.outer*32)) + y.inner)] = load(float32, C.global[((x.inner*32) + y.inner)])
            }
          }
        }
      }
    }
  }
}

@spectrometerHBH
Collaborator

spectrometerHBH commented Apr 30, 2020

TensorCore for Conv

primfn(A_1: handle, W_1: handle, Conv_1: handle) -> ()
  attr = {"tir.noalias": bool(1), "global_symbol": "main"}
  buffers = {W: Buffer(W_2: handle, float16, [3, 3, 16, 32, 16, 16], []),
           A: Buffer(A_2: handle, float16, [16, 14, 14, 16, 16, 16], []),
           Conv: Buffer(Conv_2: handle, float32, [16, 14, 14, 32, 16, 16], [])}
  buffer_map = {A_1: A, Conv_1: Conv, W_1: W} {
  attr [IterVar(blockIdx.z: int32, [(nullptr)], "ThreadIndex", "blockIdx.z")] "thread_extent" = 196;
  attr [Conv.wmma.accumulator: handle] "storage_scope" = "wmma.accumulator";
  allocate(Conv.wmma.accumulator, float32, [2048])  {
    attr [Apad.shared: handle] "storage_scope" = "shared";
    allocate(Apad.shared, float16, [12288])  {
      attr [W.shared: handle] "storage_scope" = "shared";
      allocate(W.shared, float16, [12288])  {
        attr [Apad.shared.wmma.matrix_a: handle] "storage_scope" = "wmma.matrix_a";
        allocate(Apad.shared.wmma.matrix_a, float16, [512])  {
          attr [W.shared.wmma.matrix_b: handle] "storage_scope" = "wmma.matrix_b";
          allocate(W.shared.wmma.matrix_b, float16, [1024])  {
            attr [IterVar(blockIdx.x: int32, [(nullptr)], "ThreadIndex", "blockIdx.x")] "thread_extent" = 2;
            attr [IterVar(blockIdx.y: int32, [(nullptr)], "ThreadIndex", "blockIdx.y")] "thread_extent" = 4;
            attr [IterVar(threadIdx.y: int32, [(nullptr)], "ThreadIndex", "threadIdx.y")] "thread_extent" = 4;
            attr [IterVar(threadIdx.z: int32, [(nullptr)], "ThreadIndex", "threadIdx.z")] "thread_extent" = 2 {
              for (n.c.init: int32, 0, 2) {
                for (o.c.init: int32, 0, 4) {
                  eval(call("tvm_fill_fragment", [Conv.wmma.accumulator, 16, 16, 16, ((n.c.init*4) + o.c.init), float32(0)], handle, "intrin", 0))
                }
              }
              for (ic.outer: int32, 0, 8) {
                for (kh: int32, 0, 3) {
                  for (ax2: int32, 0, 3) {
                    for (ax3: int32, 0, 2) {
                      for (ax4.ax5.fused.outer: int32, 0, 8) {
                        attr [IterVar(threadIdx.x: int32, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
                        Apad.shared[((((((threadIdx.y*3072) + (threadIdx.z*1536)) + (ax2*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x)] = call("tvm_if_then_else", [((((1 <= (floordiv(blockIdx.z, 14) + kh)) and ((floordiv(blockIdx.z, 14) + kh) < 15)) and (1 <= (ax2 + floormod(blockIdx.z, 14)))) and ((ax2 + floormod(blockIdx.z, 14)) < 15)), load(float16, A_2[(((((((((((blockIdx.x*6422528) + (threadIdx.y*1605632)) + (threadIdx.z*802816)) + (kh*57344)) + (blockIdx.z*4096)) + (ax2*4096)) + (ic.outer*512)) + (ax3*256)) + (ax4.ax5.fused.outer*32)) + threadIdx.x) - 61440)]), float16(0)], float16, "pure_intrin", 0)
                      }
                    }
                  }
                  for (ax1: int32, 0, 3) {
                    for (ax2_1: int32, 0, 2) {
                      attr [IterVar(threadIdx.x, [(nullptr)], "ThreadIndex", "threadIdx.x")] "thread_extent" = 32;
                      W.shared[ramp((((((ax1*4096) + (ax2_1*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)] = load(float16x8, W_2[ramp(((((((((kh*393216) + (ax1*131072)) + (ic.outer*16384)) + (ax2_1*8192)) + (blockIdx.y*2048)) + (threadIdx.y*512)) + (threadIdx.z*256)) + (threadIdx.x*8)), 1, 8)])
                    }
                  }
                  for (ic.inner: int32, 0, 2) {
                    for (kw: int32, 0, 3) {
                      for (ax0: int32, 0, 2) {
                        eval(call("tvm_load_matrix_sync", [Apad.shared.wmma.matrix_a, 16, 16, 16, ax0, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), Apad.shared, ((((threadIdx.y*3072) + (ax0*1536)) + (kw*512)) + (ic.inner*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                      }
                      for (ax3_1: int32, 0, 4) {
                        eval(call("tvm_load_matrix_sync", [W.shared.wmma.matrix_b, 16, 16, 16, ax3_1, call("tvm_access_ptr", [call("type_annotation", [], float16, "pure_intrin", 0), W.shared, ((((kw*4096) + (ic.inner*2048)) + (threadIdx.z*1024)) + (ax3_1*256)), 256, 1], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                      }
                      for (n.c: int32, 0, 2) {
                        for (o.c: int32, 0, 4) {
                          eval(call("tvm_mma_sync", [Conv.wmma.accumulator, ((n.c*4) + o.c), Apad.shared.wmma.matrix_a, n.c, W.shared.wmma.matrix_b, o.c, Conv.wmma.accumulator, ((n.c*4) + o.c)], handle, "intrin", 0))
                        }
                      }
                    }
                  }
                }
              }
              for (n.inner: int32, 0, 2) {
                for (o.inner: int32, 0, 4) {
                  eval(call("tvm_store_matrix_sync", [Conv.wmma.accumulator, 16, 16, 16, ((n.inner*4) + o.inner), call("tvm_access_ptr", [call("type_annotation", [], float32, "pure_intrin", 0), Conv_2, (((((((blockIdx.x*12845056) + (threadIdx.y*3211264)) + (n.inner*1605632)) + (blockIdx.z*8192)) + (blockIdx.y*2048)) + (threadIdx.z*1024)) + (o.inner*256)), 256, 2], handle, "intrin", 0), 16, "row_major"], handle, "intrin", 0))
                }
              }
            }
          }
        }
      }
    }
  }
}
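In this Conv dump, each thread owns a 2×4 grid of 16×16 WMMA accumulator fragments, and the expression `((n.c*4) + o.c)` in the `tvm_fill_fragment` / `tvm_mma_sync` calls is the row-major linearization of that grid (`n.c` in `[0, 2)`, `o.c` in `[0, 4)`). A small sketch of the indexing, with names chosen to mirror the dump:

```python
def fragment_index(n_c, o_c, o_extent=4):
    # Row-major linearization of the per-thread fragment grid,
    # matching ((n.c*4) + o.c) in the dump above.
    return n_c * o_extent + o_c

# Enumerating fragments in the same nesting order as the
# n.c.init / o.c.init loops visits indices 0..7 exactly once.
indices = [fragment_index(n, o) for n in range(2) for o in range(4)]
```

This is why the fill loop at the top and the `tvm_store_matrix_sync` loop at the bottom both iterate `2 × 4` times: together they touch all eight fragments.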

@tqchen tqchen closed this as completed Jun 26, 2020