Handle max_pool2d input dataformat float32 #1389

Closed
dgolubovicTT opened this issue Nov 25, 2024 · 4 comments · Fixed by #1657

@dgolubovicTT
Contributor

When I run the ResNet test from Forge I get an unexpected error:

loc("max_pool2d_17"("forward":4294967295:3591)): error: 'ttnn.max_pool2d' op ttnn.max_pool2d currently only supports an input type of bfloat16. Recieved 'f32'.

It turns out this is due to an assert in mlir::tt::ttnn::Conv2dOp::verify() added in a PR.

So if ttnn.max_pool2d only supports bfloat16, we shouldn't just fail compilation when its input is float32. We should probably insert a cast op to handle this and move on with the compile.

@sdjordjevicTT can we prioritize this because it is a blocker for ResNet bringup?
fyi @nvukobratTT

@sdjordjevicTT
Contributor

Yeah, I agree with you @dgolubovicTT. We will take a look at it, thanks for reporting! :)

@dgolubovicTT
Contributor Author

I've made a TVM pattern callback that inserts a cast before and after each max_pool2d, and it solved the issue. This was just to test it locally; I still believe that Forge (TVM included) shouldn't have to be aware of TTNN constraints.
There are more TTNN operations that require bfloat16 inputs, or even bfloat16 weights. For example, the embedding op requires the embedding weights to be bfloat16.
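
For illustration, a minimal sketch of such a callback using TVM's Relay dataflow-pattern API (an untested reconstruction of the idea, not the exact code; attribute forwarding is abbreviated):

```python
from tvm import relay
from tvm.relay.dataflow_pattern import DFPatternCallback, is_op, rewrite, wildcard


class CastAroundMaxPool2d(DFPatternCallback):
    """Wrap every float32 nn.max_pool2d in casts: bfloat16 in, float32 out."""

    def __init__(self):
        super().__init__(require_type=True)
        self.input = wildcard()
        self.pattern = is_op("nn.max_pool2d")(self.input)

    def callback(self, pre, post, node_map):
        inp = node_map[self.input][0]
        # Only touch float32 pools; this also keeps the rewriter from
        # re-matching its own bfloat16 output on the next pass.
        if pre.args[0].checked_type.dtype != "float32":
            return post
        attrs = pre.attrs
        pooled = relay.nn.max_pool2d(
            relay.cast(inp, "bfloat16"),
            pool_size=attrs.pool_size,
            strides=attrs.strides,
            padding=attrs.padding,
            layout=attrs.layout,
            ceil_mode=attrs.ceil_mode,
        )  # remaining attrs omitted for brevity
        return relay.cast(pooled, "float32")


# Applied to a Relay module, e.g.:
# mod["main"] = rewrite(CastAroundMaxPool2d(), mod["main"])
```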

So now a more general question arises:

  1. The default data format in torch is float32, and many model weights come in that format (see the snippet after this list).
  2. tt-metal generally works in bfloat16, but it lets inputs be float32 and then implicitly casts them. Sometimes, though, ttnn ops require inputs or weights to already be bfloat16, so we may end up inserting back-and-forth casts throughout the graph. This is OK for a start, but we should push for ttnn to accept float32 and cast it implicitly, as it already does in other ops (e.g. ttnn.add).
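
To make (1) and (2) concrete, here is a purely illustrative PyTorch snippet (hypothetical, just showing the data-format round trip such a constraint forces on the graph):

```python
import torch

# (1) torch defaults to float32, so most weights and activations arrive as f32.
assert torch.get_default_dtype() == torch.float32
pool = torch.nn.MaxPool2d(kernel_size=3, stride=2, padding=1)  # as in ResNet
x = torch.randn(1, 64, 112, 112)  # float32 activation

# (2) if a backend op only accepts bfloat16, casts have to wrap the op:
y = pool(x.to(torch.bfloat16)).to(torch.float32)  # cast in, pool, cast back out
print(y.dtype)  # torch.float32
```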

@sdjordjevicTT
Contributor

@dgolubovicTT I agree with you, you shouldn't have to worry about data format constraints on the Forge-FE side.

Some TTNN ops can automatically do the cast, but others fail if the inputs aren't in the expected data format. We have opened an issue on our side to handle such cases:
#1433

For each workaround that we introduce in the tt-mlir stack, we are filing an issue on the metal side to track the resolution. Once the issue is resolved on the metal side, we will remove the workarounds from the compiler code.

@dgolubovicTT
Contributor Author

This sounds promising. Thanks!
