From 7ec2b44aadbb6bd094dd356ad4a7b83c9bec25e8 Mon Sep 17 00:00:00 2001 From: Chen-Pang He Date: Wed, 16 Oct 2024 17:35:45 +0000 Subject: [PATCH] #13385: Low-level interface for `ttnn.round` on Wormhole --- .../metal/llk_api/llk_math_unary_sfpu_api.h | 2 +- .../llk_math_eltwise_unary_sfpu_round.h | 27 +++++++++++ .../metal/llk_api/llk_sfpu_types.h | 3 +- .../compute_kernel_api/eltwise_unary/round.h | 45 +++++++++++++++++++ .../eltwise_unary/sfpu_split_includes.h | 4 ++ 5 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_round.h create mode 100644 tt_metal/include/compute_kernel_api/eltwise_unary/round.h diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h index 33489e62388..806ec8e404b 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_math_unary_sfpu_api.h @@ -20,13 +20,13 @@ #include "llk_math_eltwise_unary_sfpu_signbit.h" #include "llk_math_eltwise_unary_sfpu_floor.h" #include "llk_math_eltwise_unary_sfpu_ceil.h" +#include "llk_math_eltwise_unary_sfpu_round.h" #include "llk_math_eltwise_unary_sfpu_silu.h" #include "llk_math_eltwise_unary_sfpu_square.h" #include "llk_math_eltwise_unary_sfpu_tanh.h" #include "llk_math_eltwise_unary_sfpu_topk.h" #include "llk_math_eltwise_unary_sfpu_unary_comp.h" #include "llk_math_eltwise_unary_sfpu_trigonometry.h" -#include "llk_math_eltwise_unary_sfpu_unary_comp.h" #include "llk_math_eltwise_unary_sfpu_remainder.h" #include "llk_math_eltwise_unary_sfpu_bitwise_xor.h" #include "llk_math_eltwise_unary_sfpu_bitwise_not.h" diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_round.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_round.h new file mode 100644 index 00000000000..f9500a95d55 --- /dev/null +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/llk_math_eltwise_unary_sfpu_round.h @@ -0,0 +1,27 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + +#include "llk_math_eltwise_unary_sfpu_init.h" +#include "llk_math_eltwise_unary_sfpu_params.h" +#include "ckernel_sfpu_round.h" + +namespace ckernel { + +// New LLK SFPU APIs + +template +inline void llk_math_eltwise_unary_sfpu_round_init() { + llk_math_eltwise_unary_sfpu_init(); +} + +template +inline void llk_math_eltwise_unary_sfpu_round(uint dst_index, int vector_mode = (int)VectorMode::RC) { + llk_math_eltwise_unary_sfpu_params + (ckernel::sfpu::calculate_round, + dst_index, vector_mode); +} + +} diff --git a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h index 266e1641173..428d7b9b33d 100644 --- a/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h +++ b/tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h @@ -87,5 +87,6 @@ enum SfpuType { ceil, unused, reshuffle_rows, - cumsum + cumsum, + round, }; diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/round.h b/tt_metal/include/compute_kernel_api/eltwise_unary/round.h new file mode 100644 index 00000000000..4c1b049da03 --- /dev/null +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/round.h @@ -0,0 +1,45 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. +// +// SPDX-License-Identifier: Apache-2.0 + +#pragma once + + +#include "compute_kernel_api/common_globals.h" +#ifdef TRISC_MATH +#include "llk_math_eltwise_unary_sfpu_round.h" +#define MAIN math_main() +#define MATH(x) x +#else +#define MATH(x) +#endif + + + +namespace ckernel { + +/** + * Please refer to documentation for any_init. + */ +ALWI void round_tile_init() { + MATH(( llk_math_eltwise_unary_sfpu_round_init() )); +} + +/** + * Performs round operation on each row of a tile. + * in DST register at index tile_index. The DST register buffer must be in + * acquired state via *acquire_dst* call. This call is blocking and is only + * available on the compute engine. + * + * Return value: None + * + * | Argument | Description | Type | Valid Range | Required | + * |-----------------|----------------------------------------------------------------------------|----------|-------------------------------------------------------|----------| + * | idst | The index of the tile in DST register buffer to modify the sign bit of | uint32_t | Must be less than the size of the DST register buffer | True | + */ +ALWI void round_tile(uint32_t idst) { + MATH(( llk_math_eltwise_unary_sfpu_round(idst) )); +} + + +} // namespace ckernel diff --git a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h index 8a54ca1e7b6..31fd9343490 100644 --- a/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h +++ b/tt_metal/include/compute_kernel_api/eltwise_unary/sfpu_split_includes.h @@ -92,6 +92,10 @@ #include "compute_kernel_api/eltwise_unary/floor.h" #endif +#if SFPU_OP_ROUND_INCLUDE +#include "compute_kernel_api/eltwise_unary/round.h" +#endif + #if SFPU_OP_LEFT_SHIFT_INCLUDE #include "compute_kernel_api/eltwise_unary/left_shift.h" #endif