Add conversion for StableHLO scatter op to TTIR and TTNN dialect #1279
+338
−3
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
As part of the effort to run stableHLO models on TT silicon, lowering scatter op through TTIR and TTNN dialects.
This op currently has limited functionality, but enough to pass the current models, primarily due to the current limitations in tt-metal:
The op only supports one index, mimicking Torch's
select_scatter
op https://pytorch.org/docs/stable/generated/torch.select_scatter.htmlFurthermore, since we do not currently support any indexing on the tt-metal side (Limitations in TT-Scatter Op tt-metal#4294), it will ignore the ttir index variable and use index 0, as supported by tt-metal.
The StableHLO scatter op also supports adding a custom function for merging the two tensor operands of the scatter. Since this is not supported in tt-metal, I have added the check to see if the function is just mapping from one tensor to another (the default
select_scatter
behaviour) and assert on that. I have opened a related tt-mlir issue (Support of custom functions for the scatter op in TTNN dialect. #1278), but not a tt-metal one.Relevant TTIR dialect issue: Add ttir.scatter op to the TTIR dialect #1325
Edit:
As advised in the comments changed the runtime pipeline of the scatter op to go through the
composite_binary
path, omitting any additional (currently not used because they are not supported by tt-metal) op parameters and attributes.