Handle max_pool2d input dataformat float32 #1389
Comments
Yeah, I agree with you @dgolubovicTT. We will take a look at it, thanks for reporting! :)
I've made a TVM pattern callback that inserts a cast before and after each max_pool2d, and it solved the issue. This was just to test it locally; I still believe that Forge (TVM included) shouldn't be aware of TTNN constraints. So now the more general question arises:
@dgolubovicTT I agree with you: you shouldn't have to worry about TTNN data format constraints in Forge-FE. Some TTNN ops perform the cast automatically, but others fail if their inputs aren't in the required data format. We have filed an issue on our side to handle such cases. For each workaround we introduce in the tt-mlir stack, we file an issue on the metal side to track the resolution; once that issue is resolved on the metal side, we will remove the workaround from the compiler code.
This sounds promising. Thanks!
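For max_pool2d in particular, wrapping the op in casts (as in the TVM callback mentioned above) should be numerically benign: rounding to bfloat16 is monotonic non-decreasing, so taking the max after rounding gives exactly the same values as taking the max in full precision and rounding the result. The pure-Python sketch below checks this on random data. `to_bf16`, `max_pool2d` and `cast` are illustrative helpers (not TTNN or tt-mlir APIs); bfloat16 is emulated by rounding the float32 bit pattern to its top 16 bits with ties to even, Python floats stand in for float32 values, and NaN/inf, padding and dilation are ignored.

```python
import random
import struct

def to_bf16(x: float) -> float:
    """Round x to the nearest bfloat16 value (ties to even), returned as a Python float.

    bfloat16 keeps the top 16 bits of an IEEE-754 float32. Finite inputs only:
    NaN, inf and float32 overflow are not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # x -> float32 bit pattern
    lsb = (bits >> 16) & 1                               # lowest bit that survives
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000            # round to nearest even, drop low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]

def max_pool2d(img, k, s):
    """Naive 2D max pooling: square k x k window, stride s, no padding or dilation."""
    h, w = len(img), len(img[0])
    return [[max(img[i + di][j + dj] for di in range(k) for dj in range(k))
             for j in range(0, w - k + 1, s)]
            for i in range(0, h - k + 1, s)]

def cast(img, fn):
    """Apply an element-wise conversion to a 2D list."""
    return [[fn(v) for v in row] for row in img]

random.seed(0)
img = [[random.uniform(-10.0, 10.0) for _ in range(8)] for _ in range(8)]

# Path 1: cast to bf16, then pool (the bf16 -> f32 cast back is exact).
wrapped = max_pool2d(cast(img, to_bf16), k=3, s=2)
# Path 2: pool in full precision, then round the result to bf16.
reference = cast(max_pool2d(img, k=3, s=2), to_bf16)

# Monotonic rounding means max(round(a), round(b)) == round(max(a, b)).
assert wrapped == reference
```

In other words, under these assumptions the only difference from an f32 max_pool2d is that each output value is rounded to bfloat16; no error accumulates across the pooling window.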
When I run the ResNet test from Forge, I get an unexpected error:
loc("max_pool2d_17"("forward":4294967295:3591)): error: 'ttnn.max_pool2d' op ttnn.max_pool2d currently only supports an input type of bfloat16. Recieved 'f32'.
It turns out this is caused by an assert in the verifier
mlir::tt::ttnn::Conv2dOp::verify()
added in PR. So if ttnn.max_pool2d only supports bfloat16, we shouldn't simply fail compilation when its input is float32. We should probably insert a cast op to handle this and continue compiling.
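A minimal sketch of what such a cast-insertion rewrite does, on a toy IR: each f32 max_pool2d becomes cast(f32 -> bf16), max_pool2d in bf16, cast(bf16 -> f32). `Op` and `insert_bf16_casts` are hypothetical names made up for illustration; this is neither TVM's pattern-callback API nor the actual tt-mlir workaround infrastructure, where the same idea would presumably be an MLIR rewrite pattern.

```python
from dataclasses import dataclass, field

@dataclass
class Op:
    """One node of a toy, hypothetical IR (not TVM Relay or MLIR)."""
    name: str                 # op kind, e.g. "max_pool2d"
    inputs: list              # names of input values
    output: str               # name of the produced value
    attrs: dict = field(default_factory=dict)

def insert_bf16_casts(ops, dtypes, target="max_pool2d"):
    """Wrap each f32 `target` op in cast(f32 -> bf16) ... cast(bf16 -> f32).

    `ops` is a topologically ordered op list; `dtypes` maps value name -> dtype
    and is updated in place for the new intermediate values. The original output
    name is reused for the final cast, so downstream consumers are untouched.
    """
    rewritten = []
    for op in ops:
        if op.name != target or dtypes[op.inputs[0]] != "f32":
            rewritten.append(op)  # leave every other op as-is
            continue
        src = op.inputs[0]
        src_bf16 = f"{src}_bf16"
        out_bf16 = f"{op.output}_bf16"
        rewritten.append(Op("cast", [src], src_bf16, {"dtype": "bf16"}))
        rewritten.append(Op(op.name, [src_bf16], out_bf16, dict(op.attrs)))
        rewritten.append(Op("cast", [out_bf16], op.output, {"dtype": "f32"}))
        dtypes[src_bf16] = "bf16"
        dtypes[out_bf16] = "bf16"
    return rewritten

# Usage: conv2d -> max_pool2d -> relu, all in f32.
graph = [
    Op("conv2d", ["x", "w"], "y"),
    Op("max_pool2d", ["y"], "z", {"kernel": 3, "stride": 2}),
    Op("relu", ["z"], "r"),
]
dtypes = {"x": "f32", "w": "f32", "y": "f32", "z": "f32", "r": "f32"}
new_graph = insert_bf16_casts(graph, dtypes)
for op in new_graph:
    print(op.name, op.inputs, "->", op.output)
```

The rewritten graph is conv2d, cast, max_pool2d, cast, relu, and relu still consumes `z` as f32, so nothing upstream (Forge or TVM) has to know about the bfloat16 constraint. The naive `_bf16` naming assumes no name clashes, which a real pass would need to handle.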
@sdjordjevicTT, can we prioritize this? It is a blocker for the ResNet bring-up.
fyi @nvukobratTT