Fixing reshape op so it supports reshaping of scalars #1322

Merged: 15 commits, Dec 4, 2024


@ajakovljevicTT commented Nov 19, 2024

As described in #1306, tt-xla produces some reshapes on scalars, which our StableHLO to TTIR dialect lowering does not currently support (issue #1317). In addition, tt-metal currently throws an error when reshaping to a 1-dimensional shape (issue tenstorrent/tt-metal#15201).
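For illustration only, here is a minimal Python sketch of why reshaping a scalar to a 1-dimensional shape is shape-valid. It assumes a scalar is modelled as a rank-0 tensor with shape `()`; `can_reshape` is a hypothetical helper, not a tt-mlir or tt-metal function.

```python
import math
from typing import Tuple


def can_reshape(src: Tuple[int, ...], dst: Tuple[int, ...]) -> bool:
    """A reshape is shape-valid when both shapes hold the same number of elements.

    Hypothetical helper for illustration. The product of an empty shape is 1,
    so a scalar (shape ()) holds exactly one element.
    """
    return math.prod(src) == math.prod(dst)
```

Under this model a scalar can be reshaped to `(1,)` or `(1, 1)`, but not to `(2,)`.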

To fix both of these issues and fully support scalar testing on tt-xla, I've pulled in some of the changes from @mmanzoorTT's PR (#1319); I will merge this PR after that one. In addition, the reshape lowering pass now omits the ttir.reshape op entirely when the input's type already matches the target shape, which removes the unnecessary calls to the ttnn reshape op coming from tt-xla.
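The elision described above can be sketched roughly as follows, in plain Python rather than the MLIR C++ used in the repository. `TensorType`, `Value`, `ReshapeOp` and `elide_identity_reshape` are hypothetical names for illustration, and the sketch assumes that "same type" means equal shape and equal element type.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class TensorType:
    shape: Tuple[int, ...]  # () models a scalar (rank-0 tensor)
    dtype: str


@dataclass(frozen=True)
class Value:
    name: str
    type: TensorType


@dataclass(frozen=True)
class ReshapeOp:
    operand: Value
    result_type: TensorType


def elide_identity_reshape(op: ReshapeOp) -> Optional[Value]:
    """Return the operand when the reshape changes nothing, else None.

    Hypothetical helper: None means the op must be kept and lowered as usual.
    """
    if op.operand.type == op.result_type:
        return op.operand
    return None
```

A reshape from `(1,)` to `(1,)` is replaced by its operand, while a reshape from a scalar `()` to `(1,)` is kept because the types differ.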

Fixes #1306

Edit: @sdjordjevicTT pointed out that there is a folder for transform passes on TTIR ops, so I have moved my pass there.

@kmitrovicTT

Run

> pre-commit install
> pre-commit run --all-files

to fix the failing check.

@mmanzoorTT

There is another PR (#1252 by @uazizTT) that handles the same issue for the broadcast op (removing redundant ops). These two PRs handle the same problem in two different locations: the redundant broadcast op is removed in the TTIR->TTIR conversion pass, whereas the redundant reshape op is removed in the StableHLO->TTIR conversion pass. The two PRs should remove redundant ops in a consistent way.

@mmanzoorTT

Please add some tests for these changes.

@ajakovljevicTT

ajakovljevicTT commented Nov 19, 2024

> There is another PR (#1252 by @uazizTT) that handles the same issue for the broadcast op (removing redundant ops). These two PRs handle the same problem in two different locations: the redundant broadcast op is removed in the TTIR->TTIR conversion pass, whereas the redundant reshape op is removed in the StableHLO->TTIR conversion pass. The two PRs should remove redundant ops in a consistent way.

The reshape op here is removed in the TTIR->TTIR decomposition, which is currently part of the larger TTIR->TTNN conversion, not StableHLO->TTIR. I will look into the referenced PR to see how it is solved there.
I personally prefer the elimination to happen in the TTIR->TTNN part of the pipeline, since it then doesn't disturb the simple StableHLO->TTIR conversion that we agreed on.

@ajakovljevicTT force-pushed the ajakovljevic/constant_fix branch from 4067b4e to ca82cbe (November 20, 2024 13:54)
@ajakovljevicTT force-pushed the ajakovljevic/constant_fix branch from 7877456 to 6d4ff2f (November 22, 2024 12:39)
@ajakovljevicTT force-pushed the ajakovljevic/constant_fix branch from 6d4ff2f to 03d62a5 (November 25, 2024 13:48)

@azecevicTT left a comment

Curious why you decided to go with a whole pass. Isn't a fold sufficient here?

@ajakovljevicTT

> Curious why you decided to go with a whole pass. Isn't a fold sufficient here?

@azecevicTT I just wanted to be consistent with #1252, which does a similar thing for broadcast. Do you think it should be moved to a fold mechanism?

@azecevicTT

> > Curious why you decided to go with a whole pass. Isn't a fold sufficient here?

> @azecevicTT I just wanted to be consistent with #1252, which does a similar thing for broadcast. Do you think it should be moved to a fold mechanism?

I see, but I think fold is ideal for simple rewrites like this (and for broadcast too). In the future we will probably have tens if not hundreds of rewrite patterns of this kind, and having a separate pass for each of them can become burdensome from a maintainability standpoint. It would be best to use existing infrastructure, or to group them under one pass that operates on all ops. At the moment, fold seems like the best option.
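To make the maintainability argument concrete, here is a rough Python sketch of the "one generic driver, many per-op folds" idea. It is a toy model, not MLIR's fold infrastructure: `FOLDERS`, `register_fold`, `run_folds` and the tuple encoding of ops are all hypothetical, and it assumes that a reshape or broadcast is redundant exactly when its input and output shapes are equal.

```python
from typing import Callable, Dict, List, Tuple

# An op is modelled as (name, input_shape, output_shape); hypothetical encoding.
Op = Tuple[str, Tuple[int, ...], Tuple[int, ...]]
FoldFn = Callable[[Op], bool]  # True means "op is redundant, forward its input"

FOLDERS: Dict[str, FoldFn] = {}


def register_fold(op_name: str) -> Callable[[FoldFn], FoldFn]:
    """Attach a fold rule to an op name instead of writing a dedicated pass."""
    def decorator(fn: FoldFn) -> FoldFn:
        FOLDERS[op_name] = fn
        return fn
    return decorator


@register_fold("reshape")
def fold_reshape(op: Op) -> bool:
    _, in_shape, out_shape = op
    return in_shape == out_shape


@register_fold("broadcast")
def fold_broadcast(op: Op) -> bool:
    _, in_shape, out_shape = op
    return in_shape == out_shape


def run_folds(ops: List[Op]) -> List[Op]:
    """Single generic driver: drop every op whose registered fold marks it redundant."""
    kept = []
    for op in ops:
        fold = FOLDERS.get(op[0])
        if fold is not None and fold(op):
            continue
        kept.append(op)
    return kept
```

Adding a new simplification then means registering one more small function, while the driver stays unchanged.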

@ajakovljevicTT

ajakovljevicTT commented Nov 29, 2024

> > > Curious why you decided to go with a whole pass. Isn't a fold sufficient here?

> > @azecevicTT I just wanted to be consistent with #1252, which does a similar thing for broadcast. Do you think it should be moved to a fold mechanism?

> I see, but I think fold is ideal for simple rewrites like this (and for broadcast too). In the future we will probably have tens if not hundreds of rewrite patterns of this kind, and having a separate pass for each of them can become burdensome from a maintainability standpoint. It would be best to use existing infrastructure, or to group them under one pass that operates on all ops. At the moment, fold seems like the best option.

@azecevicTT I added folding and checked that it works on our current setup with the test in the PR. It seems to be working, so I added the omission as a fold instead of a pass. Thanks for pointing that out!

@ajakovljevicTT force-pushed the ajakovljevic/constant_fix branch 2 times, most recently from 7ab3f08 to 3c4d9f7 (December 2, 2024 07:45)
@ajakovljevicTT

Hi all, could any of the people who looked at this PR (or anyone else) find the time to take a second look and leave comments and/or approve it? Thanks!


@sdjordjevicTT left a comment

Nice and elegant! Thanks.

@ajakovljevicTT force-pushed the ajakovljevic/constant_fix branch 2 times, most recently from 4008902 to fdfde28 (December 3, 2024 13:43)
@ajakovljevicTT force-pushed the ajakovljevic/constant_fix branch from fdfde28 to 1736871 (December 3, 2024 15:23)
@ajakovljevicTT merged commit 824b256 into main on Dec 4, 2024
21 checks passed
Development

Successfully merging this pull request may close these issues.

TT-xla not supporting scalar values
10 participants