-
Notifications
You must be signed in to change notification settings - Fork 14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Inline all functions that are called. #898
Conversation
277be35
to
de1e98a
Compare
de1e98a
to
2696df8
Compare
Out of curiosity, will forge ever generate something like below?
|
registry.addExtension<TTIRDialect>( | ||
+[](MLIRContext *ctx, TTIRDialect *dialect) { | ||
dialect->addInterfaces<TTIRInlinerInterface>(); | ||
}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does this have to be defined as an extension? Or could you just do
addInterfaces<TTIRInlinerInterface>();
at the dialect initialization.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I could do that. But for FuncDialect I’m forced to register it since its initialize method is inside the toolchain. So, I thought it would be cleaner to just explicitly register them both rather than register only the FuncDialect Inliner and add the TTIRDialect Inliner in Initialize()
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I see your point. I am not sure should we do it this way, as extensions are used to extend some dialect that is already built for you. Maybe I am missing knowledge on dialect extensions. @sdjordjevicTT what do you think?
Btw func inliner depends on CF dialect, and creates CF branch op.
void handleTerminator(Operation *op, Block *newDest) const final {
...
builder.create<cf::BranchOp>(op->getLoc(), newDest, returnOp.getOperands());
are we sure this op won't be inserted in our IR? We have no lowering for CF dialect for any backend.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I believe that only hits for functions defined inside other functions or as lambdas to certain ops. Which I guess could be a problem. Another idea is to create our own Call
, Func
, and Return
ops in TTIR so we don't need to rely on the Func
dialect at all, and we can just use the Inliner I've provided. When lowering to TTIR from whatever frontend we'd just need to replace the Func::FuncOp
generation with TTIR::FuncOp
etc.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think I can also define my own Inliner for the func dialect (which wouldn't have the overloaded handleTerminator
that creates cf
ops) and register that instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It does seem like arith dialect uses the interface approach, can we experiment with that @LPanosTT?
namespace {
/// This class defines the interface for handling inlining for arithmetic
/// dialect operations.
struct ArithInlinerInterface : public DialectInlinerInterface {
using DialectInlinerInterface::DialectInlinerInterface;
/// All arithmetic dialect ops can be inlined.
bool isLegalToInline(Operation *, Region *, bool, IRMapping &) const final {
return true;
}
};
} // namespace
void arith::ArithDialect::initialize() {
...
addInterfaces<ArithInlinerInterface>();
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah torch mlir does it this way too: https://github.com/llvm/torch-mlir/blob/8787970afed3c4e1497fb24c4fdeec179fcb61f6/lib/Dialect/Torch/IR/TorchDialect.cpp#L32
I think it seems like it can be done with only changes to TTIRDialect.cpp
. I'm not sure we need to register the inliner on behalf of func.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See their InitAll for reference too https://github.com/llvm/torch-mlir/blob/8787970afed3c4e1497fb24c4fdeec179fcb61f6/lib/InitAll.cpp#L13.
It seems like all they do is:
- Implement the dialect interface
- addInterfaces
- Register func dialect inliner extension
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nsmithtt it looks like they do register the inliner on behalf of func though
void mlir::torch::registerAllExtensions(mlir::DialectRegistry ®istry) {
mlir::func::registerInlinerExtension(registry);
tensor::registerInferTypeOpInterfaceExternalModels(registry);
}
I can move the registration of the TTIR Inliner interface to TTIRDialect.cpp and use addInterfaces
I've tried and it works. So I'll do that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@nsmithtt @rpavlovicTT Should I create my own inliner for func? One that doesn't override the handleTerminator
for Block
? Just so that we leave the cf
dialect entirely out of the picture.
@mtopalovicTT I don’t think forge will since we only have one forge graph and no 'func' ops. However JAX certainly can (I.e relu can be put in a separate function defined as a maximum against 0). I think it depends on the actual structure of the model in JAX and if there are functions that are called in the python definition |
7451218
to
4da241c
Compare
registry.addExtension<TTIRDialect>( | ||
+[](MLIRContext *ctx, TTIRDialect *dialect) { | ||
dialect->addInterfaces<TTIRInlinerInterface>(); | ||
}); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah torch mlir does it this way too: https://github.com/llvm/torch-mlir/blob/8787970afed3c4e1497fb24c4fdeec179fcb61f6/lib/Dialect/Torch/IR/TorchDialect.cpp#L32
I think it seems like it can be done with only changes to TTIRDialect.cpp
. I'm not sure we need to register the inliner on behalf of func.
132b8c5
to
bce63ef
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @LPanosTT! This looks great
bce63ef
to
e0b227f
Compare
@LPanosTT, I think I have a build fix over here: I've only tested on macOS so far, still need to get ubuntu to sign off and ideally one of the FE CI's to run against it. |
e0b227f
to
816aa8a
Compare
816aa8a
to
66df189
Compare
66df189
to
03dd47e
Compare
fyi @wooseokTT |
Added pass that will make-private all functions which are calledThe inliners will inline all functions - however it will leave those functions in the module if they are public even when they are dead code. Setting them to private ensures they are deleted from the module.Since we are inlining all functions that are called, if the module has multiple distinct programs they will all remain afterward. This is because the topmost function in the call stack is the only non-dead-code function that doesn't get called.
Disabling the inliner pass is as simple as removing the added lines in
lib/Dialect/TTNN/Pipelines/TTNNPipelines.cpp
If you wish to not register them altogether you'll have to delete their invocations in
lib/RegisterAll.cpp::registerAllExtensions