Optimize ttnn.round with a direct implementation #13385

Open
Tracked by #13795
jdh8 opened this issue Oct 2, 2024 · 5 comments · May be fixed by #13851
Labels
feature-request External feature request MCW op_cat: eltwise perf for issues tracking performance problems/improvements

Comments

@jdh8
Contributor

jdh8 commented Oct 2, 2024

Rounding is only supported on Wormhole, and Wormhole already has float_to_int16, which computes exactly this rounding when the value is in range.
https://github.com/tenstorrent/tt-metal/blob/main/docs/source/tt-metalium/tt_metal/apis/kernel_apis/sfpu/llk.rst#wormhole-only

However, ttnn.round is currently implemented as a composite of ttnn.floor, ttnn.add, ttnn.where, and other ops. Its banker's-rounding branch:

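```cpp
} else {  // Bankers' Rounding
    // floor_res and is_odd() are defined earlier in the enclosing function (not shown);
    // from its use below, floor_res holds the floor of input.
    // Non-tie values: floor(input + 0.5), with 0.4 added instead for inputs in [0.4, 0.5].
    Tensor rounded_non_half = ttnn::floor(
        ttnn::add(
            input,
            ttnn::where(ttnn::logical_and(ttnn::ge(input, 0.4), ttnn::le(input, 0.5)), 0.4f, 0.5f, output_mem_config.value()),
            std::nullopt,
            output_mem_config),
        output_mem_config.value());
    // Fractional part and exact-half test.
    Tensor fractional_part = ttnn::subtract(input, floor_res, std::nullopt, output_mem_config);
    Tensor is_half = ttnn::eq(fractional_part, 0.5, std::nullopt, output_mem_config);
    // At an exact half, add 1 to an odd floor so the result is even.
    Tensor rounded_half =
        ttnn::add(floor_res, is_odd(floor_res, output_mem_config), std::nullopt, output_mem_config);
    return ttnn::where(is_half, rounded_half, rounded_non_half, output_mem_config.value());
}
```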

In turn, ttnn.floor calls functions that effectively compute ttnn.round (a conversion to an integer and back):

```cpp
vInt tmp = float_to_int16(result, 0); //TODO: Replace float_to_int16 to float_to_int32 once it is available
result = int32_to_float(tmp, 0);
```
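To make the nesting concrete, here is a minimal host-side sketch of how floor can be derived from a round-to-nearest primitive: round first, then step down by one when the rounded value overshoots x. It assumes std::nearbyint (default round-to-nearest-even mode) as a stand-in for the hardware float_to_int16/int32_to_float pair. floor_via_round and check_floor_via_round are hypothetical names, and the code surrounding the quoted kernel lines is not shown here, so this is not a claim about how ttnn.floor is actually written.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical sketch: floor(x) from a round-to-nearest primitive.
// std::nearbyint (default mode: ties to even) stands in for a hardware
// float-to-int-and-back conversion such as float_to_int16/int32_to_float.
double floor_via_round(double x) {
    double r = std::nearbyint(x);  // nearest integer, may be above x
    return r > x ? r - 1.0 : r;    // overshot: step down to the integer below x
}

// Usage checks.
void check_floor_via_round() {
    assert(floor_via_round(2.7) == 2.0);    // nearest is 3, which overshoots
    assert(floor_via_round(-2.3) == -3.0);  // nearest is -2, which overshoots
    assert(floor_via_round(2.5) == 2.0);    // tie rounds to even 2, no fix-up
    assert(floor_via_round(-2.5) == -3.0);  // tie rounds to -2, fix-up to -3
}
```

Since ttnn.round is then built on floor, a tensor value passes through a rounding step inside floor and again through the composite logic, which is the redundancy a direct implementation would remove.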

Rounding to the nearest integer is extremely useful for argument reduction. A direct implementation could be reused in other mathematical functions (mostly elementary functions), such as:

  • Exponential functions
  • Trigonometric functions
  • ttnn.pow
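As an illustration of argument reduction, the sketch below computes exp(x) by writing x = k*ln2 + r with k = round(x / ln2), so that |r| <= ln2/2 and exp(x) = 2^k * exp(r). It is a host-side C++ sketch, not SFPU code: std::nearbyint stands in for a direct round primitive, the Taylor series for exp(r) is for illustration only, and exp_via_round is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical sketch: exp(x) with argument reduction by rounding.
// x = k*ln2 + r with k = round(x / ln2), so |r| <= ln2/2 (about 0.347),
// and exp(x) = 2^k * exp(r). std::nearbyint stands in for a direct round.
double exp_via_round(double x) {
    const double ln2 = 0.69314718055994530942;
    double k = std::nearbyint(x / ln2);  // nearest integer to x / ln2
    double r = x - k * ln2;              // reduced argument
    // Taylor series for exp(r); converges fast because |r| is small.
    double term = 1.0, sum = 1.0;
    for (int i = 1; i <= 16; ++i) {
        term *= r / i;
        sum += term;
    }
    return std::ldexp(sum, static_cast<int>(k));  // scale by 2^k exactly
}

// Usage check against the standard library.
void check_exp_via_round() {
    const double xs[] = {-5.0, -0.3, 0.0, 1.0, 10.0};
    for (double x : xs) {
        double ref = std::exp(x);
        assert(std::fabs(exp_via_round(x) - ref) <= 1e-12 * ref);
    }
}
```

Production implementations typically split ln2 into high and low parts so that r stays accurate for large |x|; the same round-then-reduce shape applies to trigonometric functions with multiples of pi/2.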
@jdh8 jdh8 added the feature-request External feature request label Oct 2, 2024
@jdh8 jdh8 added op_cat: eltwise perf for issues tracking performance problems/improvements and removed community labels Oct 15, 2024
@jdh8 jdh8 linked a pull request Oct 16, 2024 that will close this issue
jdh8 added a commit that referenced this issue Oct 16, 2024
@mouliraj-mcw
Contributor

Hi @jdh8,
I examined your approach and found that it doesn't address rounding to a specific number of decimal places (e.g., 2 or 3 decimal places).
Could you please share your thoughts on how this could be handled?

@jdh8
Contributor Author

jdh8 commented Oct 17, 2024

Thanks for pointing it out! I missed the decimals parameter.

It can be handled by multiplying by 10**n. Specifically,

round(x, n) = 10**-n * round(10**n * x)
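A minimal sketch of this identity in C++, assuming ties-to-even rounding (std::nearbyint in the default rounding mode) for the inner round; round_decimals is a hypothetical name. For n >= 0 it divides by 10**n instead of multiplying by 10**-n, because 10**-n is not exactly representable in binary and multiplying by it would add a second rounding error. For negative n it scales the other way.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical sketch of round(x, n) = 10**-n * round(10**n * x), ties to even.
// For n >= 0, divide by 10**n rather than multiply by 10**-n: 10**-n is not
// exactly representable in binary, so multiplying by it would add error.
double round_decimals(double x, int n) {
    if (n >= 0) {
        double scale = std::pow(10.0, n);
        return std::nearbyint(x * scale) / scale;
    }
    double scale = std::pow(10.0, -n);  // negative n: round to tens, hundreds, ...
    return std::nearbyint(x / scale) * scale;
}

// Usage checks.
void check_round_decimals() {
    assert(round_decimals(0.125, 2) == 0.12);      // 12.5 is an exact tie -> 12
    assert(round_decimals(1234.0, -2) == 1200.0);  // 12.34 -> 12
    assert(round_decimals(94.5, 0) == 94.0);       // n = 0 is plain roundeven
}
```

Most decimal "ties" such as 2.675 at n = 2 are not exactly representable in binary, so the scaled product may or may not land exactly on a tie; the check uses 0.125 because it is exact.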

@jdh8
Contributor Author

jdh8 commented Oct 17, 2024

I have two proposals:

  1. Implement a native roundeven(x), conceptually round(x, 0), and then build round(x, n) on top of it. (Named after C23 roundeven)
  2. Implement round(x, n) directly and natively.

Which approach looks better?
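For proposal 1, a native roundeven does not have to go through an integer conversion, so it is not limited to the range of float_to_int16. Below is a minimal, hardware-agnostic sketch using the classic "add and subtract 2^23" trick for IEEE-754 binary32; it assumes the default round-to-nearest-even mode and no excess precision (FLT_EVAL_METHOD == 0, as on x86-64 SSE). roundeven_f32 is a hypothetical name, and this is not a description of what the Wormhole SFPU provides.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical sketch of a native roundeven(x) for IEEE-754 binary32.
// Assumes round-to-nearest-even mode and no excess precision (FLT_EVAL_METHOD == 0).
// For |x| < 2^23, the exact sum |x| + 2^23 lies in [2^23, 2^24), where adjacent
// floats are exactly 1 apart, so the addition itself rounds the fraction away,
// with ties to even. Subtracting 2^23 back is exact.
float roundeven_f32(float x) {
    const float two23 = 8388608.0f;        // 2^23
    float ax = std::fabs(x);
    if (!(ax < two23)) return x;           // already integral, or inf/NaN
    volatile float shifted = ax + two23;   // volatile: keep the float rounding step
    float r = shifted - two23;
    return std::copysign(r, x);            // keep the sign, e.g. -0.4f -> -0.0f
}

// Usage checks.
void check_roundeven_f32() {
    assert(roundeven_f32(94.5f) == 94.0f);
    assert(roundeven_f32(95.5f) == 96.0f);
    assert(roundeven_f32(-2.5f) == -2.0f);
    assert(roundeven_f32(1.0e10f) == 1.0e10f);
    assert(std::signbit(roundeven_f32(-0.4f)));
}
```

With such a primitive, round(x, n) from proposal 1 reduces to scaling by 10**n around a single roundeven call.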

@mouliraj-mcw
Contributor

I think approach 2 would be more suitable, as its structure is more straightforward.

jdh8 added 8 commits that referenced this issue between Oct 24 and Oct 27, 2024
@eyonland eyonland added the MCW label Dec 20, 2024
@umadevimcw
Contributor

umadevimcw commented Jan 7, 2025

@jdh8 I tested the rounding in the jdh8/direct-rounding branch with reference to this comment #13851 (review) and observed the following.

torch.round uses banker's rounding, i.e. it rounds halfway cases to the nearest even integer. For example, for the input 94.5:

Torch returns 94, whereas
TT returns 95, which makes the test case fail.

94.5 is exactly halfway between 94 and 95, so it is rounded to the even neighbor, 94. Our TT implementation needs to handle this case.
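The two tie rules can be reproduced on the host with the C++ standard library: std::round rounds halfway cases away from zero, while std::nearbyint in the default rounding mode rounds them to even, matching torch.round. The sketch also includes a hypothetical round_half_even_via_floor that obtains ties-to-even from floor, a fractional-part test, and a parity test, the same ingredients as the composite ttnn::round quoted in the issue description; it is an illustration, not the fix in #13851.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper: round half to even from floor, a fractional-part test,
// and a parity test.
double round_half_even_via_floor(double x) {
    double f = std::floor(x);
    double frac = x - f;                  // fractional part, in [0, 1]
    if (frac > 0.5) return f + 1.0;
    if (frac < 0.5) return f;
    return std::fmod(f, 2.0) == 0.0 ? f : f + 1.0;  // exact tie: even neighbor
}

// Usage checks; the standard calls assume the default rounding mode.
void check_tie_rules() {
    assert(std::round(94.5) == 95.0);       // ties away from zero: same as TT's 95
    assert(std::nearbyint(94.5) == 94.0);   // ties to even, like torch.round
    assert(round_half_even_via_floor(94.5) == 94.0);
    assert(round_half_even_via_floor(-3.5) == -4.0);
    assert(round_half_even_via_floor(2.6) == 3.0);
}
```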

[Image: comparison of outputs; red is TT's output, green is Torch's output]
