From 054a38da0d6346ad292460514bfa6479b6b2ce14 Mon Sep 17 00:00:00 2001 From: Yan Zaretskiy Date: Wed, 23 Oct 2024 18:40:04 +0000 Subject: [PATCH] #13644: Add support for tensor-scalar binary ops --- .clang-tidy | 2 +- ...r_bcast_scalar_interleaved_partitioned.cpp | 67 +++++++ .../ttnn/operations/eltwise/binary/binary.cpp | 98 +++++----- .../ttnn/operations/eltwise/binary/binary.hpp | 8 +- .../binary/device/binary_device_operation.cpp | 174 ++++++++++++------ .../binary/device/binary_device_operation.hpp | 48 +++-- ...t_and_width_multi_core_program_factory.cpp | 157 ++++++++-------- ...cast_height_multi_core_program_factory.cpp | 12 +- ...core_sharded_optimized_program_factory.cpp | 12 +- ...ght_multi_core_sharded_program_factory.cpp | 14 +- ...dcast_width_multi_core_program_factory.cpp | 12 +- ...lement_wise_multi_core_program_factory.cpp | 16 +- 12 files changed, 380 insertions(+), 240 deletions(-) create mode 100644 ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/dataflow/reader_bcast_scalar_interleaved_partitioned.cpp diff --git a/.clang-tidy b/.clang-tidy index c8b42ce70d1e..03cf529f05c1 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -3,7 +3,7 @@ Checks: > performance-*, modernize-*, readability-*, - cppcoreguidelines-* + cppcoreguidelines-*, -modernize-use-trailing-return-type CheckOptions: diff --git a/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/dataflow/reader_bcast_scalar_interleaved_partitioned.cpp b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/dataflow/reader_bcast_scalar_interleaved_partitioned.cpp new file mode 100644 index 000000000000..7f35fa443e73 --- /dev/null +++ b/ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/dataflow/reader_bcast_scalar_interleaved_partitioned.cpp @@ -0,0 +1,67 @@ +// SPDX-FileCopyrightText: © 2023 Tenstorrent Inc. 
+//
+// SPDX-License-Identifier: Apache-2.0
+
+#include <stdint.h>
+#include "dataflow_api.h"
+#include "ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/dataflow/generate_bcast_scalar.hpp"
+
+
+void kernel_main() {
+    auto src0_addr = get_arg_val<uint32_t>(0);
+    auto packed_scalar = get_arg_val<uint32_t>(1);
+    auto num_tiles = get_arg_val<uint32_t>(2);
+    auto HtWt = get_arg_val<uint32_t>(3);
+    auto base_start_id_HtWt = get_arg_val<uint32_t>(4);
+    auto curr_id_from_base = get_arg_val<uint32_t>(5);
+    auto bcast_id = get_arg_val<uint32_t>(6);
+
+    #ifndef IN0_SHARDED
+    constexpr bool src0_is_dram = get_compile_time_arg_val(0) == 1;
+    #endif
+
+    constexpr uint32_t cb_id_in0 = 0;
+    constexpr uint32_t cb_id_in1 = 1;
+    constexpr uint32_t onetile = 1;
+
+    // single-tile ublocks
+    const uint32_t in0_tile_bytes = get_tile_size(cb_id_in0);
+    const DataFormat in0_data_format = get_dataformat(cb_id_in0);
+    const DataFormat in1_data_format = DataFormat::Float16_b;
+
+    uint32_t l1_write_addr_in0;
+    uint32_t l1_write_addr_in1;
+
+    #ifndef IN0_SHARDED
+    const InterleavedAddrGenFast<src0_is_dram> s0 = {
+        .bank_base_address = src0_addr,
+        .page_size = in0_tile_bytes,
+        .data_format = in0_data_format
+    };
+    #else
+    cb_reserve_back(cb_id_in0, num_tiles);
+    cb_push_back(cb_id_in0, num_tiles);
+    #endif
+
+    generate_bcast_unary_scalar(cb_id_in1, packed_scalar);
+
+    for (uint32_t i = 0; i < num_tiles; i++) {
+        uint32_t curr_id = base_start_id_HtWt + curr_id_from_base;
+
+        #ifndef IN0_SHARDED
+        cb_reserve_back(cb_id_in0, onetile);
+        l1_write_addr_in0 = get_write_ptr(cb_id_in0);
+        noc_async_read_tile(curr_id, s0, l1_write_addr_in0);
+        noc_async_read_barrier();
+        cb_push_back(cb_id_in0, onetile);
+        #endif
+
+        curr_id_from_base++;
+
+        if (curr_id_from_base == HtWt) {
+            base_start_id_HtWt += HtWt;
+            curr_id_from_base = 0;
+        }
+    }
+}
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
index af6ee7605c67..b9956a45fadd 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.cpp
@@ -98,11 +98,8 @@ inline Tensor binary_impl(
     return output_tensor;
 }
 
-template <BinaryOpType binary_op_type>
-auto preprocess_inputs(
-    const Tensor& input_tensor_a_arg,
-    const Tensor& input_tensor_b_arg) {
-
+template <BinaryOpType binary_op_type>
+auto preprocess_inputs(const Tensor &input_tensor_a_arg, const Tensor &input_tensor_b_arg) {
     Tensor input_tensor_a = input_tensor_a_arg;
     Tensor input_tensor_b = input_tensor_b_arg;
 
@@ -149,8 +146,8 @@ Tensor BinaryOperation<binary_op_type>::invoke(
     std::optional<Tensor> optional_output_tensor,
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
-
-    auto [input_tensor_a, input_tensor_b] = detail::preprocess_inputs(input_tensor_a_arg, input_tensor_b_arg);
+    auto [input_tensor_a, input_tensor_b] =
+        detail::preprocess_inputs<binary_op_type>(input_tensor_a_arg, input_tensor_b_arg);
 
     return ttnn::prim::binary(
         queue_id,
@@ -184,53 +181,44 @@ Tensor BinaryOperation<binary_op_type>::invoke(
         input_tensor_a_activation);
 }
 
-// TODO: this case should use BinaryWithScalarProgramConfig and there should be a custom kernel to run this
-// Currently, this is exactly how tt::tt_metal::add_unary works
 template <BinaryOpType binary_op_type>
 Tensor BinaryOperation<binary_op_type>::invoke(
+    uint8_t queue_id,
     const ttnn::Tensor &input_tensor_a,
-    const float scalar,
-    const std::optional<const DataType> &dtype,
+    float scalar,
+    const std::optional<const DataType> &output_dtype,
     const std::optional<MemoryConfig> &memory_config,
     const std::optional<Tensor> &optional_output_tensor,
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
-    return BinaryOperation::invoke(
-        DefaultQueueId,
+    return ttnn::prim::binary(
+        queue_id,
         input_tensor_a,
         scalar,
-        dtype,
+        binary_op_type,
+        output_dtype,
         memory_config,
         optional_output_tensor,
         activations,
         input_tensor_a_activation);
 }
 
+// TODO: this case should use BinaryWithScalarProgramConfig and there should be a custom kernel to run this
+// Currently, this is exactly how tt::tt_metal::add_unary works
 template <BinaryOpType binary_op_type>
 Tensor BinaryOperation<binary_op_type>::invoke(
-    uint8_t queue_id,
     const ttnn::Tensor &input_tensor_a,
-    const float scalar,
-    const std::optional<const DataType> &dtype,
+    float scalar,
+    const std::optional<const DataType> &output_dtype,
     const std::optional<MemoryConfig> &memory_config,
     const std::optional<Tensor> &optional_output_tensor,
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
-    using namespace tt::constants;
-    // Cast Float Scalar to a device tensor
-    auto host_buffer = owned_buffer::create<::bfloat16>(static_cast<std::size_t>(TILE_HEIGHT * TILE_WIDTH));
-    host_buffer[0] = scalar;
-    Tensor scalar_tensor_host = Tensor(
-        OwnedStorage{host_buffer},
-        ttnn::Shape(std::array<std::uint32_t, 2>{1, 1}, std::array<std::uint32_t, 2>{TILE_HEIGHT, TILE_WIDTH}),
-        DataType::BFLOAT16,
-        Layout::TILE);
-    Tensor scalar_tensor_device = scalar_tensor_host.to(input_tensor_a.device());
-    // TODO(arakhmati): #7637 pass in memory_config instead of operation::DEFAULT_OUTPUT_MEMORY_CONFIG
     return BinaryOperation::invoke(
+        DefaultQueueId,
         input_tensor_a,
-        scalar_tensor_device,
-        dtype,
+        scalar,
+        output_dtype,
         memory_config,
         optional_output_tensor,
         activations,
         input_tensor_a_activation);
 }
 
@@ -253,7 +241,8 @@ Tensor RelationalBinary<binary_op_type>::invoke(
         "If both output dtype and output tensor provided dtype should match");
     }
 
-    auto [input_tensor_a, input_tensor_b] = detail::preprocess_inputs(input_tensor_a_arg, input_tensor_b_arg);
+    auto [input_tensor_a, input_tensor_b] =
+        detail::preprocess_inputs<binary_op_type>(input_tensor_a_arg, input_tensor_b_arg);
 
     auto output_memory_config = memory_config.value_or(input_tensor_a.memory_config());
     DataType dtype = output_dtype.value_or(input_tensor_a.get_dtype());
@@ -334,25 +323,34 @@ Tensor RelationalBinary<binary_op_type>::invoke(
 
 template <BinaryOpType binary_op_type>
 Tensor InplaceRelationalBinary<binary_op_type>::invoke(
-    const Tensor &input_tensor_a_arg,
-    const Tensor &input_tensor_b_arg) {
-
-    return RelationalBinary<binary_op_type>::invoke(input_tensor_a_arg, input_tensor_b_arg, std::nullopt, std::nullopt, input_tensor_a_arg, std::nullopt, std::nullopt);
+    const Tensor &input_tensor_a_arg, const Tensor &input_tensor_b_arg) {
+    return RelationalBinary<binary_op_type>::invoke(
+        input_tensor_a_arg,
+        input_tensor_b_arg,
+        std::nullopt,
+        std::nullopt,
+        input_tensor_a_arg,
+        std::nullopt,
+        std::nullopt);
 }
 
 template <BinaryOpType binary_op_type>
-Tensor InplaceRelationalBinary<binary_op_type>::invoke(
-    const ttnn::Tensor &input_tensor_a,
-    const float scalar) {
-    return RelationalBinary<binary_op_type>::invoke(input_tensor_a, scalar, std::nullopt, std::nullopt, input_tensor_a, std::nullopt, std::nullopt);
+Tensor InplaceRelationalBinary<binary_op_type>::invoke(const ttnn::Tensor &input_tensor_a, const float scalar) {
+    return RelationalBinary<binary_op_type>::invoke(
+        input_tensor_a, scalar, std::nullopt, std::nullopt, input_tensor_a, std::nullopt, std::nullopt);
 }
 
 template <BinaryOpType binary_op_type>
 Tensor InplaceLogicalBinary<binary_op_type>::invoke(
-    const Tensor &input_tensor_a_arg,
-    const Tensor &input_tensor_b_arg) {
-
-    return BinaryOperation<binary_op_type>::invoke(input_tensor_a_arg, input_tensor_b_arg, std::nullopt, std::nullopt, input_tensor_a_arg, std::nullopt, std::nullopt);
+    const Tensor &input_tensor_a_arg, const Tensor &input_tensor_b_arg) {
+    return BinaryOperation<binary_op_type>::invoke(
+        input_tensor_a_arg,
+        input_tensor_b_arg,
+        std::nullopt,
+        std::nullopt,
+        input_tensor_a_arg,
+        std::nullopt,
+        std::nullopt);
 }
 
 template <BinaryOpType binary_op_type>
@@ -361,8 +359,14 @@ Tensor InplaceBinaryOperation<binary_op_type>::invoke(
     const Tensor &input_tensor_b_arg,
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
-
-    return BinaryOperation<binary_op_type>::invoke(input_tensor_a_arg, input_tensor_b_arg, std::nullopt, std::nullopt, input_tensor_a_arg, activations, input_tensor_a_activation);
+    return BinaryOperation<binary_op_type>::invoke(
+        input_tensor_a_arg,
+        input_tensor_b_arg,
+        std::nullopt,
+        std::nullopt,
+        input_tensor_a_arg,
+        activations,
+        input_tensor_a_activation);
 }
 
 template <BinaryOpType binary_op_type>
@@ -371,7 +375,8 @@ Tensor InplaceBinaryOperation<binary_op_type>::invoke(
     const ttnn::Tensor &input_tensor_a,
     const float scalar,
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
-    return BinaryOperation::invoke(input_tensor_a, scalar, std::nullopt, std::nullopt, input_tensor_a, activations, input_tensor_a_activation);
+    return BinaryOperation<binary_op_type>::invoke(
+        input_tensor_a, scalar, std::nullopt, std::nullopt, input_tensor_a, activations, input_tensor_a_activation);
 }
 
 template struct BinaryOperation;
 template struct BinaryOperation;
@@ -403,7 +408,6 @@ template struct InplaceRelationalBinary;
 template struct InplaceRelationalBinary;
 template struct InplaceRelationalBinary;
 
-
 template struct InplaceLogicalBinary;
 template struct InplaceLogicalBinary;
 
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
index 66b7fe2ff496..2f73a760ae55 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/binary.hpp
@@ -43,8 +43,8 @@ struct BinaryOperation {
     // Currently, this is exactly how tt::tt_metal::add_unary works
     static Tensor invoke(
         const ttnn::Tensor &input_tensor_a,
-        const float scalar,
-        const std::optional<const DataType> &dtype = std::nullopt,
+        float scalar,
+        const std::optional<const DataType> &output_dtype = std::nullopt,
         const std::optional<MemoryConfig> &memory_config = std::nullopt,
         const std::optional<Tensor> &optional_output_tensor = std::nullopt,
         std::optional<unary::FusedActivations> activations = std::nullopt,
@@ -53,8 +53,8 @@ struct BinaryOperation {
     static Tensor invoke(
         uint8_t queue_id,
         const ttnn::Tensor &input_tensor_a,
-        const float scalar,
-        const std::optional<const DataType> &dtype = std::nullopt,
+        float scalar,
+        const std::optional<const DataType> &output_dtype = std::nullopt,
         const std::optional<MemoryConfig> &memory_config = std::nullopt,
         const std::optional<Tensor> &optional_output_tensor = std::nullopt,
         std::optional<unary::FusedActivations> activations = std::nullopt,
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp
index e42f7e72d70d..97e0c943f51c 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp
@@ -5,9 +5,9 @@
 #include "binary_device_operation.hpp"
 
 #include "tt_metal/common/constants.hpp"
+#include "tt_metal/common/work_split.hpp"
 #include "tt_metal/host_api.hpp"
 #include "ttnn/operations/data_movement/bcast/bcast.hpp"
-#include "tt_metal/common/work_split.hpp"
 
 namespace ttnn::operations::binary {
 
@@ -15,7 +15,12 @@ BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_f
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     ZoneScopedN("BinaryDeviceOperation::select_program_factory");
     const auto& input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape;
-    const auto& input_shape_b = tensor_args.input_tensor_b.tensor_attributes->shape;
+
+    if (operation_attributes.scalar.has_value()) {
+        return BroadcastHeightAndWidthMultiCore{};
+    }
+
+    const auto& input_shape_b = tensor_args.input_tensor_b->tensor_attributes->shape;
 
     auto height_a = input_shape_a[-2];
     auto width_a = input_shape_a[-1];
@@ -25,21 +30,24 @@ BinaryDeviceOperation::program_factory_t BinaryDeviceOperation::select_program_f
 
     if (height_a == height_b and width_a == width_b) {
         return ElementWiseMultiCore{};
-    } else if (height_b == 1 or width_b == 1) {
+    }
+    if (height_b == 1 or width_b == 1) {
         if (height_b == 1 and width_b == 1) {
             return BroadcastHeightAndWidthMultiCore{};
-        } else if (height_b == 1) {
-            if(tensor_args.input_tensor_a.is_sharded()){
-                if (tensor_args.input_tensor_a.get_legacy_shape()[0] == tensor_args.input_tensor_b.get_legacy_shape()[0]
-                    || tensor_args.input_tensor_a.get_legacy_shape()[0] > 1
-                    and tensor_args.input_tensor_b.get_legacy_shape()[0] == 1){
-                    return BroadcastHeightMultiCoreShardedOptimized{};
-                } else {
-                    return BroadcastHeightMultiCoreSharded{};
+        }
+        if (height_b == 1) {
+            if (tensor_args.input_tensor_a.is_sharded()) {
+                if (tensor_args.input_tensor_a.get_padded_shape()[0] ==
+                        tensor_args.input_tensor_b->get_padded_shape()[0] ||
+                    tensor_args.input_tensor_a.get_padded_shape()[0] > 1 and
+                        tensor_args.input_tensor_b->get_padded_shape()[0] == 1) {
+                    return BroadcastHeightMultiCoreShardedOptimized{};
                 }
+                return BroadcastHeightMultiCoreSharded{};
             }
             return BroadcastHeightMultiCore{};
-        } else if (width_b == 1) {
+        }
+        if (width_b == 1) {
             return BroadcastWidthMultiCore{};
         }
     }
@@ -53,64 +61,71 @@ void BinaryDeviceOperation::validate_on_program_cache_miss(
     const auto& input_tensor_b = tensor_args.input_tensor_b;
     const auto& output_tensor = tensor_args.output_tensor;
 
+    TT_FATAL(input_tensor_b.has_value() != attributes.scalar.has_value(), "Either the tensor b or scalar should be set");
+
     BinaryDeviceOperation::validate_on_program_cache_hit(attributes, tensor_args);
 
-    TT_FATAL(
-        input_tensor_a.device() == input_tensor_b.device(),
-        "Operands to eltwise binary need to be on the same device!");
-    TT_FATAL(
-        (input_tensor_a.get_layout() == Layout::TILE && input_tensor_b.get_layout() == Layout::TILE),
-        "Inputs to eltwise binary must be tilized");
-    // Only case when op is not valid is if we have different shardings in any of 2 inputs and output - not supported at the momment
+    TT_FATAL(input_tensor_a.get_layout() == Layout::TILE, "Input to eltwise binary must be tilized");
+
+    bool tensor_b_sharded = false;
+
+    if (input_tensor_b.has_value()) {
+        tensor_b_sharded = input_tensor_b->memory_config().is_sharded();
+        TT_FATAL(
+            input_tensor_a.device() == input_tensor_b->device(),
+            "Operands to eltwise binary need to be on the same device!");
+        TT_FATAL(input_tensor_b->get_layout() == Layout::TILE, "Inputs to eltwise binary must be tilized");
+    }
+
     if (input_tensor_a.memory_config().is_sharded()) {
-        if (input_tensor_b.memory_config().is_sharded()) {
-            TT_FATAL(input_tensor_a.memory_config().memory_layout == input_tensor_b.memory_config().memory_layout, "Error");
-            TT_FATAL(input_tensor_a.shard_spec().value() == input_tensor_b.shard_spec().value(), "Error");
+        if (tensor_b_sharded) {
+            TT_FATAL(
+                input_tensor_a.memory_config().memory_layout == input_tensor_b->memory_config().memory_layout, "Error");
+            TT_FATAL(input_tensor_a.shard_spec().value() == input_tensor_b->shard_spec().value(), "Error");
         }
         if (attributes.memory_config.is_sharded()) {
             TT_FATAL(input_tensor_a.memory_config().memory_layout == attributes.memory_config.memory_layout, "Error");
         } else {
             TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
         }
-    } else if (input_tensor_b.memory_config().is_sharded()) {
+    } else if (tensor_b_sharded) {
         TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
         if (attributes.memory_config.is_sharded()) {
-            TT_FATAL(input_tensor_b.memory_config().memory_layout == attributes.memory_config.memory_layout, "Error");
+            TT_FATAL(input_tensor_b->memory_config().memory_layout == attributes.memory_config.memory_layout, "Error");
         } else {
             TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
         }
     } else {
         TT_FATAL(input_tensor_a.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
-        TT_FATAL(input_tensor_b.memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
+        TT_FATAL(
+            !input_tensor_b.has_value() or
+                (input_tensor_b->memory_config().memory_layout == TensorMemoryLayout::INTERLEAVED),
+            "Error");
         if (!attributes.memory_config.is_sharded()) {
             TT_FATAL(attributes.memory_config.memory_layout == TensorMemoryLayout::INTERLEAVED, "Error");
         }
     }
 
     auto program_factory = select_program_factory(attributes, tensor_args);
-    std::visit(
-        [&attributes](auto&& program_factory) {
-            if constexpr (std::is_same_v<std::decay_t<decltype(program_factory)>, ElementWiseMultiCore>) {
-                TT_FATAL(not attributes.activations.has_value(), "Error");
-            }
-        },
-        program_factory);
-
+    if (std::holds_alternative<ElementWiseMultiCore>(program_factory)) {
+        TT_FATAL(not attributes.activations.has_value(), "Error");
+    }
 }
+
 void BinaryDeviceOperation::validate_on_program_cache_hit(
     const operation_attributes_t& attributes, const tensor_args_t& tensor_args) {
     const auto& input_tensor_a = tensor_args.input_tensor_a;
-    const auto& input_tensor_b = tensor_args.input_tensor_b;
     const auto& output_tensor = tensor_args.output_tensor;
 
     const auto& input_shape_a = input_tensor_a.get_shape();
-    const auto& input_shape_b = input_tensor_b.get_shape();
 
     auto batch_size_0_a = input_shape_a.rank() >= 4 ? input_shape_a[-4] : 1;
     auto batch_size_1_a = input_shape_a.rank() >= 3 ? input_shape_a[-3] : 1;
     auto height_a = input_shape_a[-2];
     auto width_a = input_shape_a[-1];
 
+    const auto input_shape_b =
+        tensor_args.input_tensor_b.has_value() ? tensor_args.input_tensor_b->get_shape() : ttnn::Shape{1, 1};
     auto batch_size_0_b = input_shape_b.rank() >= 4 ? input_shape_b[-4] : 1;
     auto batch_size_1_b = input_shape_b.rank() >= 3 ? input_shape_b[-3] : 1;
     auto height_b = input_shape_b[-2];
@@ -140,7 +155,8 @@ void BinaryDeviceOperation::validate_on_program_cache_hit(
 BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes(
     const operation_attributes_t&, const tensor_args_t& tensor_args) {
     const auto input_shape_a = tensor_args.input_tensor_a.tensor_attributes->shape;
-    const auto input_shape_b = tensor_args.input_tensor_b.tensor_attributes->shape;
+    const auto& tensor_b = tensor_args.input_tensor_b;
+    const auto input_shape_b = tensor_b.has_value() ? tensor_b->tensor_attributes->shape : ttnn::Shape{1, 1};
 
     const int rank_a = input_shape_a.rank();
     const int rank_b = input_shape_b.rank();
@@ -188,13 +204,11 @@ BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_outpu
     return ttnn::Shape(output_shape, output_shape_with_tile_padding);
 }
 
-
 BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
     const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
     using namespace tt::constants;
     auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
     const auto& input_tensor_a = tensor_args.input_tensor_a;
-    const auto& input_tensor_b = tensor_args.input_tensor_b;
     const auto& output_tensor = tensor_args.output_tensor;
 
     if (output_tensor.has_value()) {
@@ -203,6 +217,7 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
 
     auto program_factory = select_program_factory(operation_attributes, tensor_args);
     if (std::holds_alternative<ElementWiseMultiCore>(program_factory)) {
+        const auto& input_tensor_b = *tensor_args.input_tensor_b;
         if (operation_attributes.memory_config.is_sharded()) {
             ShardSpec shard_spec{CoreRangeSet(), {0, 0}};
             if (input_tensor_a.memory_config().is_sharded()) {
@@ -236,7 +251,6 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
         Layout::TILE,
         input_tensor_a.device(),
         operation_attributes.memory_config);
-
 }
 
 tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash(
@@ -249,19 +263,27 @@ tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash(
         std::holds_alternative<DeviceStorage>(input_tensor_a.get_storage()),
         "Unexpected type {}",
         tt::stl::get_active_type_name_in_variant(input_tensor_a.get_storage()));
-    TT_ASSERT(
-        std::holds_alternative<DeviceStorage>(input_tensor_b.get_storage()),
-        "Unexpected type {}",
-        tt::stl::get_active_type_name_in_variant(input_tensor_b.get_storage()));
 
-    operation::Hash hash = operation::hash_operation<BinaryDeviceOperation>(
+    if (input_tensor_b.has_value()) {
+        TT_ASSERT(
+            std::holds_alternative<DeviceStorage>(input_tensor_b->get_storage()),
+            "Unexpected type {}",
+            tt::stl::get_active_type_name_in_variant(input_tensor_b->get_storage()));
+
+        return operation::hash_operation<BinaryDeviceOperation>(
+            attributes,
+            program_factory.index(),
+            input_tensor_a.dtype(),
+            std::get<DeviceStorage>(input_tensor_a.storage()).memory_config(),
+            input_tensor_b->dtype(),
+            std::get<DeviceStorage>(input_tensor_b->storage()).memory_config());
+    }
+
+    return operation::hash_operation<BinaryDeviceOperation>(
         attributes,
         program_factory.index(),
         input_tensor_a.dtype(),
-        std::get<DeviceStorage>(input_tensor_a.storage()).memory_config(),
-        input_tensor_b.dtype(),
-        std::get<DeviceStorage>(input_tensor_b.storage()).memory_config());
-    return hash;
+        std::get<DeviceStorage>(input_tensor_a.storage()).memory_config());
 }
 
 operation::OpPerformanceModel BinaryDeviceOperation::create_op_performance_model(
@@ -274,15 +296,19 @@ operation::OpPerformanceModel BinaryDeviceOperation::create_op_performance_model
     // GS specific parameters
     // 80 B/cycle unpacker BW shared
     // 128 datums per cycle math, but unpacker cant keep up
-    constexpr int num_cores = 9 * 12;
+    constexpr uint32_t num_cores = 9 * 12;
 
-    int total_bytes = 0;
+    uint32_t total_bytes = 0;
+    std::vector<Tensor> input_tensors = {input_tensor_a};
     total_bytes += input_tensor_a.volume() * input_tensor_a.element_size();
-    total_bytes += input_tensor_b.volume() * input_tensor_b.element_size();
-    int ideal_eltwise_cycles = total_bytes / 80 / num_cores;
 
+    if (input_tensor_b.has_value()) {
+        input_tensors.push_back(*input_tensor_b);
+        total_bytes += input_tensor_b->volume() * input_tensor_b->element_size();
+    }
+
+    uint32_t ideal_eltwise_cycles = total_bytes / 80 / num_cores;
     // TODO: update OpPerformanceModel to work on variadic arguments
-    operation::OpPerformanceModel result({input_tensor_a, input_tensor_b}, {output_tensor}, ideal_eltwise_cycles);
+    operation::OpPerformanceModel result(input_tensors, {output_tensor}, ideal_eltwise_cycles);
 #if 0
     tt::log_info(tt::LogOp, "BinaryDeviceOperation PerfModel:");
     tt::log_info(tt::LogOp, "\t Data (Bytes): {}", total_bytes);
@@ -291,14 +317,13 @@ operation::OpPerformanceModel BinaryDeviceOperation::create_op_performance_model
     return result;
 }
 
-
-
-std::tuple<BinaryDeviceOperation::operation_attributes_t, BinaryDeviceOperation::tensor_args_t> BinaryDeviceOperation::invoke(
-    const Tensor &input_tensor_a_arg,
-    const Tensor &input_tensor_b_arg,
+std::tuple<BinaryDeviceOperation::operation_attributes_t, BinaryDeviceOperation::tensor_args_t>
+BinaryDeviceOperation::invoke(
+    const Tensor& input_tensor_a_arg,
+    const Tensor& input_tensor_b_arg,
     BinaryOpType binary_op_type,
-    const std::optional<const DataType> &output_dtype,
-    const std::optional<MemoryConfig> &memory_config,
+    const std::optional<const DataType>& output_dtype,
+    const std::optional<MemoryConfig>& memory_config,
     std::optional<Tensor> optional_output_tensor,
     std::optional<unary::FusedActivations> activations,
     std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
@@ -313,10 +338,39 @@
+std::tuple<BinaryDeviceOperation::operation_attributes_t, BinaryDeviceOperation::tensor_args_t>
+BinaryDeviceOperation::invoke(
+    const Tensor& input_tensor_a_arg,
+    float scalar,
+    BinaryOpType binary_op_type,
+    const std::optional<const DataType>& output_dtype,
+    const std::optional<MemoryConfig>& memory_config,
+    std::optional<Tensor> optional_output_tensor,
+    std::optional<unary::FusedActivations> activations,
+    std::optional<unary::UnaryWithParam> input_tensor_a_activation) {
+    if (output_dtype.has_value() && optional_output_tensor.has_value()) {
+        TT_FATAL(
+            output_dtype.value() == optional_output_tensor.value().get_dtype(),
+            "If both output dtype and output tensor provided dtype should match");
     }
 
+    return {
+        operation_attributes_t{
+            binary_op_type,
+            activations,
+            input_tensor_a_activation,
+            scalar,
+            memory_config.value_or(input_tensor_a_arg.memory_config()),
+            output_dtype.value_or(input_tensor_a_arg.get_dtype()),
+            std::nullopt},
+        tensor_args_t{input_tensor_a_arg, std::nullopt, optional_output_tensor}};
+}
+
 }  // namespace ttnn::operations::binary
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp
index bef06ac83792..75d108d9c357 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp
@@ -5,25 +5,24 @@
 #pragma once
 
 #include <functional>
+#include <optional>
 #include <variant>
 
-#include "ttnn/common/constants.hpp"
-#include "ttnn/tensor/tensor.hpp"
-#include <optional>
-#include "ttnn/tensor/host_buffer/functions.hpp"
-#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
-#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp"
-#include "ttnn/run_operation.hpp"
 #include "tt_metal/host_api.hpp"
 #include "tt_metal/impl/dispatch/command_queue.hpp"
+#include "ttnn/common/constants.hpp"
 #include "ttnn/core.hpp"
 #include "ttnn/decorators.hpp"
 #include "ttnn/device_operation.hpp"
-#include "ttnn/types.hpp"
+#include "ttnn/operations/core/compute_kernel/compute_kernel_config.hpp"
 #include "ttnn/operations/eltwise/binary/common/binary_op_types.hpp"
 #include "ttnn/operations/eltwise/binary/common/binary_op_utils.hpp"
-#include "ttnn/decorators.hpp"
+#include "ttnn/operations/eltwise/unary/common/unary_op_types.hpp"
+#include "ttnn/run_operation.hpp"
+#include "ttnn/tensor/host_buffer/functions.hpp"
+#include "ttnn/tensor/tensor.hpp"
+#include "ttnn/types.hpp"
 
 namespace ttnn::operations::binary {
 
@@ -32,13 +31,20 @@ struct BinaryDeviceOperation {
         BinaryOpType binary_op_type;
         const std::optional<unary::FusedActivations> activations;
         const std::optional<unary::UnaryWithParam> input_tensor_a_activation;
+        const std::optional<float> scalar;
         const MemoryConfig memory_config;
         const DataType dtype;
         std::optional<DeviceComputeKernelConfig> compute_kernel_config;
+
+        tt::stl::hash::hash_t to_hash() const {
+            // hash has to exclude the scalar value
+            return tt::stl::hash::hash_objects_with_default_seed(
+                binary_op_type, activations, input_tensor_a_activation, memory_config, dtype, compute_kernel_config);
+        }
     };
     struct tensor_args_t {
         const Tensor& input_tensor_a;
-        const Tensor& input_tensor_b;
+        const std::optional<Tensor>& input_tensor_b;
         std::optional<Tensor> output_tensor;
     };
     using shape_return_value_t = ttnn::Shape;
@@ -197,14 +203,12 @@ struct BinaryDeviceOperation {
 
     static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
     static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
 
-    static shape_return_value_t compute_output_shapes(
-        const operation_attributes_t&, const tensor_args_t&);
+    static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
 
     static tensor_return_value_t create_output_tensors(
         const operation_attributes_t& operation_attributes, const tensor_args_t&);
 
-    static tt::stl::hash::hash_t compute_program_hash(
-        const operation_attributes_t&, const tensor_args_t&);
+    static tt::stl::hash::hash_t compute_program_hash(const operation_attributes_t&, const tensor_args_t&);
 
     static operation::OpPerformanceModel create_op_performance_model(
         const operation_attributes_t& attributes,
@@ -220,11 +224,21 @@ struct BinaryDeviceOperation {
         std::optional<Tensor> optional_output_tensor,
         std::optional<unary::FusedActivations> activations,
         std::optional<unary::UnaryWithParam> input_tensor_a_activation);
+
+    static std::tuple<operation_attributes_t, tensor_args_t> invoke(
+        const Tensor& input_tensor_a_arg,
+        float scalar,
+        BinaryOpType binary_op_type,
+        const std::optional<const DataType>& output_dtype,
+        const std::optional<MemoryConfig>& memory_config,
+        std::optional<Tensor> optional_output_tensor,
+        std::optional<unary::FusedActivations> activations,
+        std::optional<unary::UnaryWithParam> input_tensor_a_activation);
 };
 
 }  // namespace ttnn::operations::binary
 
-
 namespace ttnn::prim {
-constexpr auto binary = ttnn::register_operation<"ttnn::prim::binary", ttnn::operations::binary::BinaryDeviceOperation>();
-}  // namespace ttnn::prim
+constexpr auto binary =
+    ttnn::register_operation<"ttnn::prim::binary", ttnn::operations::binary::BinaryDeviceOperation>();
+}  // namespace ttnn::prim
diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
index a964c226bd6c..f427c9090cfe 100644
--- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
+++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_and_width_multi_core_program_factory.cpp
@@ -6,15 +6,14 @@
 
 #include "binary_device_operation.hpp"
 #include "impl/buffers/buffer.hpp"
-#include "ttnn/tensor/tensor.hpp"
-#include "ttnn/operations/data_movement/bcast/bcast.hpp"
-#include "tt_metal/common/work_split.hpp"
 #include "tt_metal/common/constants.hpp"
+#include "tt_metal/common/work_split.hpp"
 #include "tt_metal/detail/util.hpp"
 #include "tt_metal/host_api.hpp"
 #include "ttnn/device_operation.hpp"
-#include "tt_metal/common/constants.hpp"
-
+#include "ttnn/operations/data_movement/bcast/bcast.hpp"
+#include "ttnn/tensor/tensor.hpp"
+#include "ttnn/operations/cb_utils.hpp"
 
 namespace ttnn::operations::binary {
 
@@ -40,16 +39,14 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
     const auto& b = tensor_args.input_tensor_b;
     auto& output = tensor_return_value;
     auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type);
 
-    const auto ashape = a.get_legacy_shape();
-    const auto bshape = b.get_legacy_shape();
+    const auto ashape = a.get_padded_shape();
+    const auto bshape = b.has_value() ? b->get_padded_shape() : Shape{1, 1};
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
     uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
     uint32_t H = ashape[-2];
     uint32_t W = ashape[-1];
     uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1;
     uint32_t bC = bshape.rank() >= 3 ? bshape[-3] : 1;
-    uint32_t bH = bshape[-2];
-    uint32_t bW = bshape[-1];
 
     uint32_t NC = N * C;
     uint32_t HW = H * W;
@@ -59,9 +56,9 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
 
     uint32_t num_tensor_tiles = NC * Ht * Wt;
 
-    uint32_t bnc1 = (bN * bC == 1);
+    bool bnc1 = (bN * bC == 1);
 
-    tt_metal::Program program = tt_metal::CreateProgram();
+    auto program = tt_metal::CreateProgram();
 
     tt_metal::Device* device = a.device();
 
@@ -75,7 +72,8 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
     }
 
     tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype());
-    tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype());
+    tt::DataFormat src1_cb_data_format =
+        b.has_value() ? tt_metal::datatype_to_dataformat_converter(b->get_dtype()) : tt::DataFormat::Float16_b;
     tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype());
 
     uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format);
@@ -88,65 +86,43 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
     uint32_t num_cores_total = num_cores_x * num_cores_y;
     auto all_device_cores = CoreRange({0, 0}, {num_cores_x - 1, num_cores_y - 1});
 
-    bool row_major = false;
-    if (shard_spec.has_value()) {
-        row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR;
-    }
+    bool row_major = shard_spec.has_value() ? shard_spec->orientation == ShardOrientation::ROW_MAJOR : false;
+
     auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
         tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles, row_major);
 
     auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major);
 
-    auto src0_buffer = a.buffer();
-    auto src1_buffer = b.buffer();
-    auto dst_buffer = output.buffer();
+    auto* src0_buffer = a.buffer();
+    auto* src1_buffer = b.has_value() ? b->buffer() : nullptr;
+    auto* dst_buffer = output.buffer();
     TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!");
 
-    uint32_t src0_cb_index = 0;
     uint32_t num_input_tiles = 2;
     uint32_t num_tiles_per_shard = 0;
     if (shard_spec.has_value()) {
-        num_tiles_per_shard = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW;
+        num_tiles_per_shard = shard_spec->shape[0] * shard_spec->shape[1] / TILE_HW;
        num_tiles_per_core_group_1 = num_tiles_per_shard;
         num_tiles_per_core_group_2 = 0;
-        all_cores = shard_spec.value().grid;
+        all_cores = shard_spec->grid;
         core_group_1 = all_cores;
         core_group_2 = CoreRangeSet();
     }
 
     uint32_t num_input_tiles_cb0 = src0_sharded ? num_tiles_per_shard : num_input_tiles;
-    tt_metal::CircularBufferConfig src0_cb_config =
-        tt_metal::CircularBufferConfig(
-            num_input_tiles_cb0 * src0_single_tile_size, {{src0_cb_index, src0_cb_data_format}})
-            .set_page_size(src0_cb_index, src0_single_tile_size);
-    if (src0_sharded) {
-        src0_cb_config = src0_cb_config.set_globally_allocated_address(*a.buffer());
-    }
-    auto cb_src0 = tt_metal::CreateCircularBuffer(program, all_device_cores, src0_cb_config);
+    auto* cb_src0_buffer = src0_sharded ? src0_buffer : nullptr;
+    auto [cb_src0, cb_handle_src0] = create_cb(tt::CB::c_in0, program, all_device_cores, src0_single_tile_size, num_input_tiles_cb0, src0_cb_data_format, cb_src0_buffer);
 
-    uint32_t src1_cb_index = 1;
-    tt_metal::CircularBufferConfig src1_cb_config =
-        tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}})
-            .set_page_size(src1_cb_index, src1_single_tile_size);
-    auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, src1_cb_config);
+    uint32_t num_input_tiles_cb1 = src1_buffer != nullptr ? num_input_tiles : 1;
+    create_cb(tt::CB::c_in1, program, all_device_cores, src1_single_tile_size, num_input_tiles_cb1, src1_cb_data_format);
 
-    uint32_t output_cb_index = 16;  // output operands start at index 16
     uint32_t num_output_tiles = output_sharded ? num_tiles_per_shard : 2;
-    tt_metal::CircularBufferConfig output_cb_config =
-        tt_metal::CircularBufferConfig(num_output_tiles * dst_single_tile_size, {{output_cb_index, dst_cb_data_format}})
-            .set_page_size(output_cb_index, dst_single_tile_size);
-    if (output_sharded) {
-        output_cb_config = output_cb_config.set_globally_allocated_address(*output.buffer());
-    }
-    auto cb_output = tt_metal::CreateCircularBuffer(program, all_device_cores, output_cb_config);
-
-    bool src0_is_dram = src0_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
-    bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
-    std::vector<uint32_t> reader_compile_time_args = {(uint32_t)src0_is_dram, (uint32_t)src1_is_dram};
+    auto* cb_output_buffer = output_sharded ? dst_buffer : nullptr;
+    auto [cb_output, cb_handle_output] = create_cb(tt::CB::c_out0, program, all_device_cores, dst_single_tile_size, num_output_tiles, dst_cb_data_format, cb_output_buffer);
 
-    bool dst_is_dram = dst_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0;
-    std::vector<uint32_t> writer_compile_time_args = {(std::uint32_t)output_cb_index, (std::uint32_t)dst_is_dram};
+    auto src0_is_dram = static_cast<uint32_t>(src0_buffer->buffer_type() == tt_metal::BufferType::DRAM);
+    auto dst_is_dram = static_cast<uint32_t>(dst_buffer->buffer_type() == tt_metal::BufferType::DRAM);
 
     std::map<std::string, std::string> reader_defines;
     std::map<std::string, std::string> bcast_compute_defines = bcast_op_utils::get_defines(BcastOpDim::HW, bcast_math);
@@ -157,11 +133,25 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
     if (src0_sharded) {
         reader_defines["IN0_SHARDED"] = "1";
     }
-    KernelHandle binary_reader_kernel_id = tt_metal::CreateKernel(
-        program,
-        "ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/dataflow/reader_bcast_hw_interleaved_partitioned.cpp",
-        all_device_cores,
-        tt_metal::ReaderDataMovementConfig(reader_compile_time_args, reader_defines));
+
+    KernelHandle binary_reader_kernel_id{};
+
+    if (src1_buffer != nullptr) {
+        auto src1_is_dram = static_cast<uint32_t>(src1_buffer->buffer_type() == tt_metal::BufferType::DRAM);
+        binary_reader_kernel_id = tt_metal::CreateKernel(
+            program,
+            "ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/dataflow/"
+            "reader_bcast_hw_interleaved_partitioned.cpp",
+            all_device_cores,
+            tt_metal::ReaderDataMovementConfig({src0_is_dram, src1_is_dram}, reader_defines));
+    } else {
+        binary_reader_kernel_id = tt_metal::CreateKernel(
+            program,
+            "ttnn/cpp/ttnn/operations/data_movement/bcast/device/kernels/dataflow/"
+            "reader_bcast_scalar_interleaved_partitioned.cpp",
+            all_device_cores,
+            tt_metal::ReaderDataMovementConfig({src0_is_dram}, reader_defines));
+    }
 
     std::map<std::string, std::string> writer_defines;
     if (output_sharded) {
@@ -171,7 +161,7 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
         program,
         "ttnn/cpp/ttnn/operations/eltwise/unary/device/kernels/dataflow/writer_unary_interleaved_start_id.cpp",
         all_device_cores,
-        tt_metal::WriterDataMovementConfig(writer_compile_time_args, writer_defines));
+        tt_metal::WriterDataMovementConfig({cb_output, dst_is_dram}, writer_defines));
 
     auto bcast_kernel_id = tt_metal::CreateKernel(
         program,
@@ -193,17 +183,24 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
             continue;
         }
 
-        tt_metal::SetRuntimeArgs(
-            program,
-            binary_reader_kernel_id,
-            core,
-            {a.buffer()->address(),  // 0
-             b.buffer()->address(),
-             num_tensor_tiles_per_core,
-             HtWt,
-             num_tiles_read / HtWt * HtWt,
-             num_tiles_read % HtWt,
-             bnc1 ? 0 : num_tiles_read / HtWt});
+        std::vector<uint32_t> binary_reader_args = {
+            src0_buffer->address(),  // 0
+            0,
+            num_tensor_tiles_per_core,
+            HtWt,
+            num_tiles_read / HtWt * HtWt,
+            num_tiles_read % HtWt,
+            bnc1 ? 0 : num_tiles_read / HtWt};
+
+        if (src1_buffer != nullptr) {
+            binary_reader_args[1] = src1_buffer->address();
+        } else {
+            class bfloat16 bfloat_scalar(*operation_attributes.scalar);
+            uint32_t packed_scalar = pack_two_bfloat16_into_uint32({bfloat_scalar, bfloat_scalar});
+            binary_reader_args[1] = packed_scalar;
+        }
+
+        tt_metal::SetRuntimeArgs(program, binary_reader_kernel_id, core, binary_reader_args);
 
         tt_metal::SetRuntimeArgs(
             program,
@@ -220,7 +217,7 @@ BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::create(
             unary_writer_kernel_id,
             core,
             {
-                output.buffer()->address(),
+                dst_buffer->address(),
                 num_tensor_tiles_per_core,
                 num_tiles_read,
             });
@@ -269,7 +266,6 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a
     uint32_t num_cores_total = num_cores_x * num_cores_y;
 
     auto src_buffer_a = input_tensor_a.buffer();
-    auto src_dram_buffer_b = input_tensor_b.buffer();
     std::optional<ShardSpec> shard_spec = std::nullopt;
     bool src0_sharded = input_tensor_a.memory_config().is_sharded();
     bool out_sharded = output_tensor.memory_config().is_sharded();
@@ -282,16 +278,14 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a
 
     auto dst_buffer = output_tensor.buffer();
 
-    const auto ashape = input_tensor_a.get_legacy_shape();
-    const auto bshape = input_tensor_b.get_legacy_shape();
+    const auto ashape = input_tensor_a.get_padded_shape();
+    const auto bshape = input_tensor_b.has_value() ? input_tensor_b->get_padded_shape() : Shape{1, 1};
     uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1;
     uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1;
     uint32_t H = ashape[-2];
     uint32_t W = ashape[-1];
     uint32_t bN = bshape.rank() >= 4 ? bshape[-4] : 1;
     uint32_t bC = bshape.rank() >= 3 ? bshape[-3] : 1;
-    uint32_t bH = bshape[-2];
-    uint32_t bW = bshape[-1];
 
     uint32_t NC = N * C;
     uint32_t HW = H * W;
@@ -301,11 +295,11 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a
 
     uint32_t num_tensor_tiles = NC * Ht * Wt;
 
-    uint32_t bnc1 = (bN * bC == 1);
+    auto bnc1 = static_cast<uint32_t>(bN * bC == 1);
 
     bool row_major = false;
     if (shard_spec.has_value()) {
-        row_major = shard_spec.value().orientation == ShardOrientation::ROW_MAJOR;
+        row_major = shard_spec->orientation == ShardOrientation::ROW_MAJOR;
     }
     auto [num_cores, all_cores, core_group_1, core_group_2, num_tiles_per_core_group_1, num_tiles_per_core_group_2] =
         tt::tt_metal::split_work_to_cores(compute_with_storage_grid_size, num_tensor_tiles, row_major);
@@ -314,10 +308,10 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a
 
     if (shard_spec.has_value()) {
         uint32_t num_tiles_per_shard = 0;
-        num_tiles_per_shard = shard_spec.value().shape[0] * shard_spec.value().shape[1] / TILE_HW;
+        num_tiles_per_shard = shard_spec->shape[0] * shard_spec->shape[1] / TILE_HW;
         num_tiles_per_core_group_1 = num_tiles_per_shard;
         num_tiles_per_core_group_2 = 0;
-        all_cores = shard_spec.value().grid;
+        all_cores = shard_spec->grid;
         core_group_1 = all_cores;
         core_group_2 = CoreRangeSet();
     }
@@ -346,14 +340,21 @@ void BinaryDeviceOperation::BroadcastHeightAndWidthMultiCore::override_runtime_a
         }
 
         binary_reader_args[0] = src_buffer_a->address();
-        binary_reader_args[1] = src_dram_buffer_b->address();
+
+        if (input_tensor_b.has_value()) {
+            binary_reader_args[1] = input_tensor_b->buffer()->address();
+        } else {
+            class bfloat16 bfloat_scalar(*operation_attributes.scalar);
+            uint32_t packed_scalar = pack_two_bfloat16_into_uint32({bfloat_scalar, bfloat_scalar});
+            binary_reader_args[1] = packed_scalar;
+        }
binary_reader_args[2] = num_tensor_tiles_per_core; binary_reader_args[3] = HtWt; binary_reader_args[4] = num_tiles_read / HtWt * HtWt; binary_reader_args[5] = num_tiles_read % HtWt; binary_reader_args[6] = bnc1 ? 0 : num_tiles_read / HtWt; - bcast_kernel_args[2] = num_tensor_tiles_per_core; // Wt + bcast_kernel_args[2] = num_tensor_tiles_per_core; // Wt unary_writer_args[0] = dst_buffer->address(); unary_writer_args[1] = num_tensor_tiles_per_core; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp index 423e45487ad0..ac00c26f9a38 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_program_factory.cpp @@ -39,7 +39,7 @@ BinaryDeviceOperation ::BroadcastHeightMultiCore::create( auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type); const auto ashape = a.get_legacy_shape(); - const auto bshape = b.get_legacy_shape(); + const auto bshape = b->get_legacy_shape(); uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1; uint32_t H = ashape[-2]; @@ -64,7 +64,7 @@ BinaryDeviceOperation ::BroadcastHeightMultiCore::create( tt_metal::Device* device = a.device(); tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b->get_dtype()); tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format); @@ -84,7 +84,7 @@ BinaryDeviceOperation ::BroadcastHeightMultiCore::create( auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); auto src0_buffer = a.buffer(); - auto src1_buffer = b.buffer(); + auto src1_buffer = b->buffer(); auto dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); @@ -159,7 +159,7 @@ BinaryDeviceOperation ::BroadcastHeightMultiCore::create( 0, // 1 0, // 2 num_tensor_tiles_per_core, // 3 - b.buffer()->address(), // 4 + b->buffer()->address(), // 4 0, // 5 0, // 6 num_btensor_tiles, // 7 @@ -231,12 +231,12 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCore::override_runtime_argument uint32_t num_cores_total = num_cores_x * num_cores_y; auto src_dram_buffer_a = input_tensor_a.buffer(); - auto src_dram_buffer_b = input_tensor_b.buffer(); + auto src_dram_buffer_b = input_tensor_b->buffer(); auto dst_dram_buffer = output_tensor.buffer(); const auto ashape = input_tensor_a.get_legacy_shape(); - const auto bshape = input_tensor_b.get_legacy_shape(); + const auto bshape = input_tensor_b->get_legacy_shape(); uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; uint32_t C = ashape.rank() >= 3 ? 
ashape[-3] : 1; uint32_t H = ashape[-2]; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_optimized_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_optimized_program_factory.cpp index 6c0860f3bb12..f28c4e01db1b 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_optimized_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_optimized_program_factory.cpp @@ -39,7 +39,7 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreShardedOptimized::create( auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type); const auto ashape = a.get_legacy_shape(); - const auto bshape = b.get_legacy_shape(); + const auto bshape = b->get_legacy_shape(); uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; uint32_t C = ashape.rank() >= 3 ? ashape[-3] : 1; uint32_t H = ashape[-2]; @@ -78,7 +78,7 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreShardedOptimized::create( ncores); tt::DataFormat act_df = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - tt::DataFormat b_df = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + tt::DataFormat b_df = tt_metal::datatype_to_dataformat_converter(b->get_dtype()); tt::DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t input_tile_size = tt::tt_metal::detail::TileSize(act_df); @@ -141,7 +141,7 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreShardedOptimized::create( auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, src1_cb_config); auto src0_buffer = a.buffer(); - auto src1_buffer = b.buffer(); + auto src1_buffer = b->buffer(); auto dst_buffer = output.buffer(); bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 
1 : 0; std::vector reader_compile_time_args = {(uint32_t)src0_cb_index, (uint32_t)src1_is_dram}; @@ -212,7 +212,7 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreShardedOptimized::create( binary_reader_kernel_id, core, { - b.buffer()->address(), // (0) src1_addr + b->buffer()->address(), // (0) src1_addr Ht, // (1) Ht Wt, // (2) Wt offset, // (3) read offset in1 @@ -275,7 +275,7 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCoreShardedOptimized::override_ uint32_t Wt = 0, Ht =0; const auto ashape = input_tensor_a.get_legacy_shape(); uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3]; - uint32_t bN = input_tensor_b.get_legacy_shape()[0]; + uint32_t bN = input_tensor_b->get_legacy_shape()[0]; uint32_t NC = N*C; if(a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED){ Wt = shard_spec.shape[1] / TILE_WIDTH; @@ -321,7 +321,7 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCoreShardedOptimized::override_ binary_reader_kernel_id, core, { - b.buffer()->address(), // (0) src1_addr + b->buffer()->address(), // (0) src1_addr Ht, // (1) Ht Wt, // (2) Wt offset, // (3) read offset in1 diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_program_factory.cpp index 602c23966920..c1ea1a028f2a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_height_multi_core_sharded_program_factory.cpp @@ -39,7 +39,7 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreSharded::create( auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type); const auto ashape = a.get_legacy_shape(); - const auto bshape = b.get_legacy_shape(); + const auto bshape = b->get_legacy_shape(); uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; uint32_t C = ashape.rank() >= 3 ? 
ashape[-3] : 1; uint32_t H = ashape[-2]; @@ -74,7 +74,7 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreSharded::create( TT_FATAL(out_shard_spec.num_cores() == ncores, "Output tensor should have same number of cores {} as input tensor {}", out_shard_spec.num_cores(), ncores); tt::DataFormat act_df = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - tt::DataFormat b_df = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + tt::DataFormat b_df = tt_metal::datatype_to_dataformat_converter(b->get_dtype()); tt::DataFormat out_df = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t input_tile_size = tt::tt_metal::detail::TileSize(act_df); @@ -117,14 +117,14 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreSharded::create( .set_globally_allocated_address(*output.buffer()); auto out_cb = tt_metal::CreateCircularBuffer(program, all_cores, output_cb_config); - uint32_t num_input_tiles = (b.get_legacy_shape()[-1] * output.element_size() + TILE_HW - 1)/ TILE_HW; + uint32_t num_input_tiles = (b->get_legacy_shape()[-1] * output.element_size() + TILE_HW - 1)/ TILE_HW; uint32_t src1_cb_index = CB::c_in1; tt_metal::CircularBufferConfig src1_cb_config = tt_metal::CircularBufferConfig(num_input_tiles * input1_tile_size, {{src1_cb_index, b_df}}) .set_page_size(src1_cb_index, input1_tile_size); auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_cores, src1_cb_config); auto src0_buffer = a.buffer(); - auto src1_buffer = b.buffer(); + auto src1_buffer = b->buffer(); auto dst_buffer = output.buffer(); bool src1_is_dram = src1_buffer->buffer_type() == tt_metal::BufferType::DRAM ? 1 : 0; std::vector reader_compile_time_args = {(uint32_t)src0_cb_index, (uint32_t)src1_is_dram}; @@ -179,7 +179,7 @@ BinaryDeviceOperation::BroadcastHeightMultiCoreSharded::create( binary_reader_kernel_id, core, { - b.buffer()->address(), // 0 + b->buffer()->address(), // 0 Ht, // 1 Wt, // 2 offset, // 3 @@ -242,7 +242,7 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCoreSharded::override_runtime_a uint32_t Wt = 0, Ht =0; const auto ashape = input_tensor_a.get_legacy_shape(); uint32_t N = ashape[0], C = ashape[1], H = ashape[2], W = ashape[3]; - uint32_t bN = input_tensor_b.get_legacy_shape()[0]; + uint32_t bN = input_tensor_b->get_legacy_shape()[0]; uint32_t NC = N*C; if(a.memory_config().memory_layout == TensorMemoryLayout::BLOCK_SHARDED){ Wt = shard_spec.shape[1] / TILE_WIDTH; @@ -284,7 +284,7 @@ void BinaryDeviceOperation ::BroadcastHeightMultiCoreSharded::override_runtime_a binary_reader_kernel_id, core, { - b.buffer()->address(), // 0 + b->buffer()->address(), // 0 Ht, // 1 Wt, // 2 offset, // 3 diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp index 9b6d458581cf..82c930819097 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/broadcast_width_multi_core_program_factory.cpp @@ -38,7 +38,7 @@ BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOpe auto bcast_math = binary_op_type_to_bcast_op_math(operation_attributes.binary_op_type); const auto ashape = a.get_legacy_shape(); - const auto bshape = b.get_legacy_shape(); + const auto bshape = b->get_legacy_shape(); uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; uint32_t C = ashape.rank() >= 3 ? 
ashape[-3] : 1; uint32_t H = ashape[-2]; @@ -63,7 +63,7 @@ BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOpe tt_metal::Device* device = a.device(); tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); - tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b->get_dtype()); tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format); @@ -83,7 +83,7 @@ BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOpe auto cores = grid_to_cores(num_cores_total, num_cores_x, num_cores_y, row_major); auto src0_buffer = a.buffer(); - auto src1_buffer = b.buffer(); + auto src1_buffer = b->buffer(); auto dst_buffer = output.buffer(); TT_ASSERT(dst_buffer != nullptr, "Output buffer should be allocated on device!"); @@ -159,7 +159,7 @@ BinaryDeviceOperation::BroadcastWidthMultiCore::cached_program_t BinaryDeviceOpe 0, // 1 0, // 2 num_tensor_tiles_per_core, // 3 - b.buffer()->address(), // 4 + b->buffer()->address(), // 4 0, // 5 0, // 6 num_btensor_tiles, // 7 @@ -231,12 +231,12 @@ void BinaryDeviceOperation::BroadcastWidthMultiCore::override_runtime_arguments( uint32_t num_cores_total = num_cores_x * num_cores_y; auto src_dram_buffer_a = input_tensor_a.buffer(); - auto src_dram_buffer_b = input_tensor_b.buffer(); + auto src_dram_buffer_b = input_tensor_b->buffer(); auto dst_dram_buffer = output_tensor.buffer(); const auto ashape = input_tensor_a.get_legacy_shape(); - const auto bshape = input_tensor_b.get_legacy_shape(); + const auto bshape = input_tensor_b->get_legacy_shape(); uint32_t N = ashape.rank() >= 4 ? ashape[-4] : 1; uint32_t C = ashape.rank() >= 3 ? 
ashape[-3] : 1; uint32_t H = ashape[-2]; diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp index eb57eb345b96..bbd5370f363c 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/element_wise_multi_core_program_factory.cpp @@ -296,7 +296,7 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat tt::DataFormat src0_cb_data_format = tt_metal::datatype_to_dataformat_converter(a.get_dtype()); uint32_t src0_single_tile_size = tt_metal::detail::TileSize(src0_cb_data_format); - tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b.get_dtype()); + tt::DataFormat src1_cb_data_format = tt_metal::datatype_to_dataformat_converter(b->get_dtype()); uint32_t src1_single_tile_size = tt_metal::detail::TileSize(src1_cb_data_format); tt::DataFormat dst_cb_data_format = tt_metal::datatype_to_dataformat_converter(output.get_dtype()); uint32_t dst_single_tile_size = tt_metal::detail::TileSize(dst_cb_data_format); @@ -305,13 +305,13 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat tt::DataFormat interim_cb1_format = src1_cb_data_format; tt_metal::Buffer* src0_buffer = a.buffer(); - tt_metal::Buffer* src1_buffer = b.buffer(); + tt_metal::Buffer* src1_buffer = b->buffer(); tt_metal::Device* device = a.device(); std::optional shard_spec = std::nullopt; bool src0_sharded = a.memory_config().is_sharded(); - bool src1_sharded = b.memory_config().is_sharded(); + bool src1_sharded = b->memory_config().is_sharded(); bool out_sharded = output.memory_config().is_sharded(); auto compute_with_storage_grid_size = device->compute_with_storage_grid_size(); @@ -324,8 +324,8 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat shard_spec = a.shard_spec().value(); block_or_width_sharded = a.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; } else if (src1_sharded) { - shard_spec = b.shard_spec().value(); - block_or_width_sharded = b.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; + shard_spec = b->shard_spec().value(); + block_or_width_sharded = b->memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; } else if (out_sharded) { shard_spec = output.shard_spec().value(); block_or_width_sharded = output.memory_config().memory_layout != TensorMemoryLayout::HEIGHT_SHARDED; @@ -358,7 +358,7 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat tt_metal::CircularBufferConfig(num_input_tiles * src1_single_tile_size, {{src1_cb_index, src1_cb_data_format}}) .set_page_size(src1_cb_index, src1_single_tile_size); if (src1_sharded) { - cb_src1_config = cb_src1_config.set_globally_allocated_address(*b.buffer()); + cb_src1_config = cb_src1_config.set_globally_allocated_address(*b->buffer()); } auto cb_src1 = tt_metal::CreateCircularBuffer(program, all_device_cores, cb_src1_config); @@ -443,7 +443,7 @@ BinaryDeviceOperation::ElementWiseMultiCore::cached_program_t BinaryDeviceOperat set_eltwise_binary_runtime_args( program, a, - b, + *b, output, binary_reader_kernel_id, unary_writer_kernel_id, @@ -484,7 +484,7 @@ void BinaryDeviceOperation::ElementWiseMultiCore::override_runtime_arguments( set_eltwise_binary_runtime_args( cached_program.program, input_tensor_a, - input_tensor_b, + *input_tensor_b, 
output_tensor, shared_variables.binary_reader_kernel_id, shared_variables.unary_writer_kernel_id,
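
Reviewer note (commentary, not part of the patch): the scalar fast path added here avoids materializing a full device tensor for the scalar operand. The host side packs the float into a single uint32_t runtime argument as two bfloat16 copies (pack_two_bfloat16_into_uint32 with the same value in both halves, as in the program factory above), and the new reader kernel expands that packed value into a broadcast tile via generate_bcast_unary_scalar(). The standalone C++ sketch below mirrors that packing convention; it assumes tt-metal's truncating float-to-bfloat16 conversion, and the helper names are illustrative rather than the library's own.

    #include <cstdint>
    #include <cstdio>
    #include <cstring>

    // Truncate an IEEE-754 float to bfloat16 by keeping its upper 16 bits,
    // matching tt-metal's bfloat16(float) constructor (assumption).
    static uint16_t float_to_bfloat16_bits(float value) {
        uint32_t bits = 0;
        std::memcpy(&bits, &value, sizeof(bits));
        return static_cast<uint16_t>(bits >> 16);
    }

    // Pack the same bfloat16 scalar into both halves of a uint32_t, mirroring the
    // {bfloat_scalar, bfloat_scalar} pair passed to pack_two_bfloat16_into_uint32.
    // Because both halves are identical, the half ordering used by the real helper
    // does not matter for this use case.
    static uint32_t pack_scalar_runtime_arg(float scalar) {
        const uint32_t half = float_to_bfloat16_bits(scalar);
        return (half << 16) | half;
    }

    int main() {
        // 1.0f is 0x3F800000, so its bfloat16 truncation is 0x3F80 and the packed
        // runtime argument is expected to be 0x3F803F80.
        std::printf("packed scalar = 0x%08X\n", (unsigned)pack_scalar_runtime_arg(1.0f));
        return 0;
    }

With this patch in place, a tensor-scalar call such as ttnn::add(input, 2.0f) should route through BinaryOperation<binary_op_type>::invoke(float) into ttnn::prim::binary with operation_attributes_t::scalar set and input_tensor_b left empty, which select_program_factory() then dispatches to BroadcastHeightAndWidthMultiCore; note that to_hash() deliberately excludes the scalar value so cached programs are reused across different scalars.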