From 33855de8fb3022866ea4a8e1355135f7146029af Mon Sep 17 00:00:00 2001
From: Jiafan
Date: Sat, 14 Dec 2024 01:55:04 +0000
Subject: [PATCH] [0] reinterpret_i32 output mismatch when log printing is disabled

---
 kernels/gaudi2/reinterpret_fwd_i32.c        |  17 +++
 kernels/include/reinterpret_fwd_i32.h       |  53 +++++++++
 reinterpret_test/hpu_custom_reinterpret.cpp |  58 ++++++++++
 reinterpret_test/kernel_test.py             | 112 ++++++++++++++++++++
 reinterpret_test/setup.py                   |  26 +++++
 src/entry_points.cpp                        |  10 ++
 src/entry_points.hpp                        |   1 +
 src/gaudi2_src/reinterpret_fwd_i32.cpp      |  83 +++++++++++++++
 src/gaudi2_src/reinterpret_fwd_i32.hpp      |  26 +++++
 9 files changed, 386 insertions(+)
 create mode 100644 kernels/gaudi2/reinterpret_fwd_i32.c
 create mode 100644 kernels/include/reinterpret_fwd_i32.h
 create mode 100644 reinterpret_test/hpu_custom_reinterpret.cpp
 create mode 100755 reinterpret_test/kernel_test.py
 create mode 100644 reinterpret_test/setup.py
 create mode 100644 src/gaudi2_src/reinterpret_fwd_i32.cpp
 create mode 100644 src/gaudi2_src/reinterpret_fwd_i32.hpp

diff --git a/kernels/gaudi2/reinterpret_fwd_i32.c b/kernels/gaudi2/reinterpret_fwd_i32.c
new file mode 100644
index 0000000..2266398
--- /dev/null
+++ b/kernels/gaudi2/reinterpret_fwd_i32.c
@@ -0,0 +1,17 @@
+/**********************************************************************
+Copyright (c) 2024 Habana Labs.
+
+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or
+other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY
+DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS
+OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
+NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+********************************************************************/
+
+#include "reinterpret_fwd_i32.h"
diff --git a/kernels/include/reinterpret_fwd_i32.h b/kernels/include/reinterpret_fwd_i32.h
new file mode 100644
index 0000000..e8806f6
--- /dev/null
+++ b/kernels/include/reinterpret_fwd_i32.h
@@ -0,0 +1,53 @@
+#include "kernel_config.h"
+
+void main(const tensor input, tensor output) {
+    printf("Hello from Reinterpret kernel\n");
+    const int dim0 = 0;
+    const int dim1 = 1;
+    const int dim2 = 2;
+    const int dim3 = 3;
+    const int dim4 = 4;
+
+    const int5 index_space_start = get_index_space_offset();
+    const int5 index_space_end = get_index_space_size() + index_space_start;
+
+    const int dim0Step = 64; // we consume 64 float elements at a time (float64)
+    const int dim0Start = index_space_start[dim0] * dim0Step;
+    const int dim0End = index_space_end[dim0] * dim0Step;
+
+    const int dim1Step = 1;
+    const int dim1Start = index_space_start[dim1] * dim1Step;
+    const int dim1End = index_space_end[dim1] * dim1Step;
+
+    const int dim2Step = 1;
+    const int dim2Start = index_space_start[dim2] * dim2Step;
+    const int dim2End = index_space_end[dim2] * dim2Step;
+
+    const int dim3Step = 1;
+    const int dim3Start = index_space_start[dim3] * dim3Step;
+    const int dim3End = index_space_end[dim3] * dim3Step;
+
+    const int dim4Step = 1;
+    const int dim4Start = index_space_start[dim4] * dim4Step;
+    const int dim4End = index_space_end[dim4] * dim4Step;
+
+    int5 coords = {0, 0, 0, 0, 0};
+
+    for (int d0 = dim0Start; d0 < dim0End; d0 += dim0Step) {
+        coords[dim0] = d0;
+        for (int d1 = dim1Start; d1 < dim1End; d1 += dim1Step) {
+            coords[dim1] = d1;
+            for (int d2 = dim2Start; d2 < dim2End; d2 += dim2Step) {
+                coords[dim2] = d2;
+                for (int d3 = dim3Start; d3 < dim3End; d3 += dim3Step) {
+                    coords[dim3] = d3;
+                    for (int d4 = dim4Start; d4 < dim4End; d4 += dim4Step) {
+                        coords[dim4] = d4;
+                        // Read the bytes as i32 and store them in the output tensor.
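+                        // v_i32_ld_tnsr_b loads one 64-element int32 vector
+                        // from `input` at `coords`; v_i32_st_tnsr writes it
+                        // back unchanged, so the float bits are preserved
+                        // verbatim (no float-to-int value conversion).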
+                        v_i32_st_tnsr(coords, output, v_i32_ld_tnsr_b(coords, input));
+                    }
+                }
+            }
+        }
+    }
+}
diff --git a/reinterpret_test/hpu_custom_reinterpret.cpp b/reinterpret_test/hpu_custom_reinterpret.cpp
new file mode 100644
index 0000000..ae582ad
--- /dev/null
+++ b/reinterpret_test/hpu_custom_reinterpret.cpp
@@ -0,0 +1,58 @@
+#include "hpu_custom_op.h"
+
+#include <torch/extension.h>
+
+#include <iostream>
+
+
+bool register_custom_reinterpret() {
+    // inputs desc
+    habana::custom_op::InputDesc input_a_desc {
+        habana::custom_op::input_type::TENSOR, 0
+    };
+    std::vector<habana::custom_op::InputDesc> inputs_desc { input_a_desc };
+
+    auto output_size_lambda = [](const at::Stack& inputs) -> std::vector<int64_t> {
+        return inputs[0].toTensor().sizes().vec(); // Output shape is the same as the input tensor shape
+    };
+
+    habana::custom_op::OutputDesc output_desc{
+        0, c10::ScalarType::Int, output_size_lambda}; // Output dtype will be set in the execute function
+    std::vector<habana::custom_op::OutputDesc> outputs_desc{
+        output_desc};
+    // actual registration
+    REGISTER_CUSTOM_OP_ATTRIBUTES(
+        "custom_op::reinterpret_float", // schema name
+        "reinterpret_fwd_i32", // guid
+        inputs_desc,
+        outputs_desc,
+        nullptr);
+    std::cout << "cpp registered custom_op::reinterpret_float\n";
+    return true;
+}
+
+at::Tensor custom_reinterpret_execute(torch::Tensor input_a)
+{
+    // Registering the custom op; needs to be called only once
+    static bool registered = register_custom_reinterpret();
+    TORCH_CHECK(registered, "custom_reinterpret kernel not registered");
+    std::vector<c10::IValue> inputs{input_a};
+
+    // Get the custom op descriptor from the registry
+    auto op_desc = habana::custom_op::HabanaCustomOpDescriptor::getCustomOpDescriptor("custom_op::reinterpret_float");
+
+    // Actual call for op execution
+    std::vector<at::Tensor> output = op_desc.execute(inputs);
+
+    return output[0];
+}
+
+TORCH_LIBRARY(custom_op, m) {
+    m.def("reinterpret_float(Tensor self) -> Tensor");
+}
+TORCH_LIBRARY_IMPL(custom_op, HPU, m) {
+    m.impl("reinterpret_float", custom_reinterpret_execute);
+}
+
+
diff --git a/reinterpret_test/kernel_test.py b/reinterpret_test/kernel_test.py
new file mode 100755
index 0000000..d3172e5
--- /dev/null
+++ b/reinterpret_test/kernel_test.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+import pathlib
+import torch, logging
+import os
+file_path = pathlib.Path(__file__).parent.resolve()
+os.environ["GC_KERNEL_PATH"] += f":{file_path}/libcustom_tpc_perf_lib.so"
+
+from mpi4py import MPI
+os.environ['MASTER_ADDR'] = 'localhost' # server with rank=0 (master)
+os.environ['MASTER_PORT'] = '12355'
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+world_size = comm.Get_size()
+os.environ['RANK'] = f"{rank}"
+os.environ['WORLD_SIZE'] = f"{world_size}"
+os.environ['LOCAL_RANK'] = f"{rank}"
+
+os.environ["PT_HPU_LAZY_MODE"] = "1"
+
+
+print(f"world_size {world_size}")
+print(f"rank {rank}")
+import habana_frameworks.torch.core as htcore
+import habana_frameworks.torch as htorch
+
+class TensorChecker:
+    def __init__(self, gold=None, device=None):
+        self.device = device
+        self.aggregate_check = torch.empty(size=(0,))
+        self.gold = gold
+
+    @property
+    def gold(self):
+        return self._gold
+
+    @gold.setter
+    def gold(self, value):
+        self._gold = value if value is not None else torch.empty(size=(0,))
+        self.init_aggregate_check()
+
+    def init_aggregate_check(self):
+        self.aggregate_check = torch.ones_like(self._gold, dtype=torch.bool, device=self.device)
+
+    def check(self, answer):
+        answer = answer.to(torch.device(self.device))
+        self._gold = self._gold.to(torch.device(self.device))
+        self.aggregate_check &= torch.eq(self._gold, answer)
+
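+    # aggregate_check stays True only at positions that compared equal to
+    # gold in every iteration, so a single transient mismatch anywhere in
+    # the run is enough to flag a failure.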
+    def passed(self):
+        return torch.all(self.aggregate_check)
+
+    def failed(self):
+        return not self.passed()
+
+def log(msg):
+    print(f"Rank{rank}: {msg}", flush=True)
+
+def run(input_values):
+    tin = torch.tensor(input_values, dtype=torch.float32, device="hpu")
+    # tin = tin.unsqueeze(0).unsqueeze(0).unsqueeze(0).unsqueeze(0)
+    # print(f"tin.shape = {tin.shape}")
+    kernel_path = f"{file_path}/hpu_custom_reinterpret.cpython-310-x86_64-linux-gnu.so"
+    torch.ops.load_library(kernel_path)
+
+    # If you un-comment this print, the op starts working, but the results
+    # are incorrect, particularly when running on multiple Gaudis:
+    # log(tin)
+
+    tout = torch.ops.custom_op.reinterpret_float(tin)
+
+    # If you uncomment only the print below, all the answers returned are
+    # zero. If you uncomment both prints, the check passes, though we have
+    # seen some data mismatches on rare occasions.
+    # log(tout)
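+    # Untested hypothesis: under PT_HPU_LAZY_MODE=1 the log() calls change
+    # where the lazy graph gets cut and executed; forcing a graph break with
+    # htcore.mark_step() here would show whether the mismatch follows the
+    # graph-break placement rather than the prints themselves.
+    # htcore.mark_step()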
+
+    # tout = tout.squeeze(0).squeeze(0).squeeze(0).squeeze(0)
+    # print(f"tout.shape = {tout.shape}")
+
+    return tout
+
+
+
+def print_tensor_as_hex(tensor):
+    return [f"0x{value.item():08x}" for value in tensor.view(-1)]
+
+
+if __name__ == '__main__':
+    iterations = 10
+
+    # We convert 1.0, 2.0, and 3.0 from float to their raw binary values
+    # (a reinterpret_cast); the expected bits are listed in the gold tensor below.
+    input_values = [1.0, 2.0, 3.0]
+    #                     1.0         2.0         3.0
+    gold = torch.tensor([0x3f800000, 0x40000000, 0x40400000], dtype=torch.int32)
+    checker = TensorChecker(gold, device="cpu")
+    answers = []
+    for _ in range(iterations):
+        ans = run(input_values)
+        checker.check(ans)
+        answers.append(ans)
+
+    print("Answers:")
+    for i, v in enumerate(answers):
+        log(f"  Iteration {i} : {print_tensor_as_hex(v)}")
+
+
+    if checker.failed():
+        log(f"Mismatches detected on data during local checking on rank {rank}")
+    else:
+        log(f"No mismatches detected on rank {rank}")
diff --git a/reinterpret_test/setup.py b/reinterpret_test/setup.py
new file mode 100644
index 0000000..758d129
--- /dev/null
+++ b/reinterpret_test/setup.py
@@ -0,0 +1,26 @@
+###############################################################################
+# Copyright (C) 2020-2021 Habana Labs, Ltd. an Intel Company
+###############################################################################
+
+from setuptools import setup
+from torch.utils import cpp_extension
+from habana_frameworks.torch.utils.lib_utils import get_include_dir, get_lib_dir
+import os
+import pybind11
+
+torch_include_dir = get_include_dir()
+torch_lib_dir = get_lib_dir()
+habana_modules_directory = "/usr/include/habanalabs"
+pybind_include_path = pybind11.get_include()
+
+setup(name='hpu_custom_reinterpret',
+      ext_modules=[cpp_extension.CppExtension('hpu_custom_reinterpret', ['hpu_custom_reinterpret.cpp'],
+                   # language='c++', extra_compile_args=["-std=c++17"],
+                   libraries=['habana_pytorch_plugin'],
+                   library_dirs=[torch_lib_dir])],
+      include_dirs=[torch_include_dir,
+                    habana_modules_directory,
+                    pybind_include_path,
+                    ],
+      cmdclass={'build_ext': cpp_extension.BuildExtension})
+
diff --git a/src/entry_points.cpp b/src/entry_points.cpp
index 86a782d..b6423c4 100644
--- a/src/entry_points.cpp
+++ b/src/entry_points.cpp
@@ -37,6 +37,7 @@ NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVE
 #include "add_f32_gaudi2.hpp"
 #include "relu_all_gaudi2.hpp"
 #include "user_lut_gaudi2.hpp"
+#include "reinterpret_fwd_i32.hpp"
 #include "entry_points.hpp"
 #include <cstring>
 
@@ -141,6 +142,8 @@ tpc_lib_api::GlueCodeReturn GetKernelGuids( _IN_ tpc_lib_api::DeviceId
         ReluBwdBF16g2Instance.GetKernelName(guids[GAUDI2_KERNEL_RELU_BWD_BF16].name, ReluAllGaudi2::relu_bwd_bf16);
         UserLutGaudi2 userLutInstance;
         userLutInstance.GetKernelName(guids[GAUDI2_KERNEL_USER_LUT].name);
+        ReinterpretFwdI32 ReinterpretFwdI32Instance;
+        ReinterpretFwdI32Instance.GetKernelName(guids[GAUDI2_KERNEL_REINTERPRET_FWD_I32].name);
     }
 
     if (kernelCount != nullptr)
@@ -441,6 +444,13 @@ InstantiateTpcKernel(_IN_ tpc_lib_api::HabanaKernelParams* params,
         return userLutInstance.GetGcDefinitions(params,instance);
     }
 
+    ReinterpretFwdI32 ReinterpretFwdI32Instance;
+    ReinterpretFwdI32Instance.GetKernelName(kernelName);
+    if (strcmp(params->guid.name, kernelName) == 0)
+    {
+        return ReinterpretFwdI32Instance.GetGcDefinitions(params,instance);
+    }
+
     return tpc_lib_api::GLUE_NODE_NOT_FOUND;
 }
 
diff --git a/src/entry_points.hpp b/src/entry_points.hpp
index 7ec9f24..6f5bdf4 100644
--- a/src/entry_points.hpp
+++ b/src/entry_points.hpp
@@ -69,6 +69,7 @@ typedef enum
     GAUDI2_KERNEL_RELU_FWD_BF16,
     GAUDI2_KERNEL_RELU_BWD_BF16,
     GAUDI2_KERNEL_USER_LUT,
+    GAUDI2_KERNEL_REINTERPRET_FWD_I32,
 
     GAUDI2_KERNEL_MAX_EXAMPLE_KERNEL
 
diff --git a/src/gaudi2_src/reinterpret_fwd_i32.cpp b/src/gaudi2_src/reinterpret_fwd_i32.cpp
new file mode 100644
index 0000000..98f0c35
--- /dev/null
+++ b/src/gaudi2_src/reinterpret_fwd_i32.cpp
@@ -0,0 +1,83 @@
+#include <vector>
+#include <cstring>
+#include "reinterpret_fwd_i32.hpp" // Header for this kernel's glue code
+
+extern unsigned char _binary___reinterpret_fwd_i32_o_start;
+extern unsigned char _binary___reinterpret_fwd_i32_o_end;
+
+tpc_lib_api::GlueCodeReturn ReinterpretFwdI32::GetKernelName(
+    char kernelName [tpc_lib_api::MAX_NODE_NAME])
+{
+    strcpy(kernelName, "reinterpret_fwd_i32");
+    return tpc_lib_api::GLUE_SUCCESS;
+}
+
+
+tpc_lib_api::GlueCodeReturn ReinterpretFwdI32::GetGcDefinitions(
+    tpc_lib_api::HabanaKernelParams* in_defs,
+    tpc_lib_api::HabanaKernelInstantiation* out_defs)
+{
+    // Validate the correct number of input tensors
+    if (in_defs->inputTensorNr != 1)
+    {
+        return tpc_lib_api::GLUE_INCOMPATIBLE_INPUT_COUNT;
+    }
+    // Validate the correct number of output tensors
+    if (in_defs->outputTensorNr != 1)
+    {
+        return tpc_lib_api::GLUE_INCOMPATIBLE_OUTPUT_COUNT;
+    }
+    // Validate that the input data type is float and the output data type is int
+    if (in_defs->inputTensors[0].geometry.dataType != tpc_lib_api::DATA_F32 ||
+        in_defs->outputTensors[0].geometry.dataType != tpc_lib_api::DATA_I32)
+    {
+        return tpc_lib_api::GLUE_INCOMPATIBLE_DATA_TYPE;
+    }
+
+    // Define the index space geometry based on the output tensor dimensions.
+    // The kernel processes the tensor in 64-element chunks.
+    int elementsInVec = 64;
+    uint64_t outputSizes[gcapi::MAX_TENSOR_DIM] = {0};
+    memcpy(outputSizes, in_defs->inputTensors[0].geometry.maxSizes, sizeof(outputSizes));
+
+    // Round up to elementsInVec and divide by elementsInVec.
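+    // Note: without the round-up, an input shorter than 64 elements (such as
+    // the 3-element tensor in kernel_test.py) yields depthIndex == 0, an
+    // empty index space, and a kernel that never runs.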
+    unsigned depthIndex = (outputSizes[0] + elementsInVec - 1) / elementsInVec;
+    out_defs->indexSpaceRank = 5;
+    out_defs->indexSpaceGeometry[0] = depthIndex;
+    out_defs->indexSpaceGeometry[1] = outputSizes[1];
+    out_defs->indexSpaceGeometry[2] = outputSizes[2];
+    out_defs->indexSpaceGeometry[3] = outputSizes[3];
+    out_defs->indexSpaceGeometry[4] = outputSizes[4];
+
+    // Define the index space mapping for the input and output tensors.
+    // The mapping is direct since this kernel does not change the data layout.
+    for (uint32_t i = 0; i < out_defs->indexSpaceRank; ++i)
+    {
+        out_defs->inputTensorAccessPattern[0].mapping[i].indexSpaceDim = i;
+        out_defs->inputTensorAccessPattern[0].mapping[i].a = 1;
+        out_defs->inputTensorAccessPattern[0].mapping[i].start_b = 0;
+        out_defs->inputTensorAccessPattern[0].mapping[i].end_b = 0;
+
+        out_defs->outputTensorAccessPattern[0].mapping[i].indexSpaceDim = i;
+        out_defs->outputTensorAccessPattern[0].mapping[i].a = 1;
+        out_defs->outputTensorAccessPattern[0].mapping[i].start_b = 0;
+        out_defs->outputTensorAccessPattern[0].mapping[i].end_b = 0;
+    }
+    // Dim 0 is consumed 64 elements per index-space step, so widen its
+    // access pattern accordingly (index i covers elements [64*i, 64*i + 63]).
+    out_defs->inputTensorAccessPattern[0].mapping[0].a = elementsInVec;
+    out_defs->inputTensorAccessPattern[0].mapping[0].end_b = elementsInVec - 1;
+    out_defs->outputTensorAccessPattern[0].mapping[0].a = elementsInVec;
+    out_defs->outputTensorAccessPattern[0].mapping[0].end_b = elementsInVec - 1;
+
+    // Load the ISA binary into the descriptor
+    unsigned IsaSize = (&_binary___reinterpret_fwd_i32_o_end - &_binary___reinterpret_fwd_i32_o_start);
+    unsigned givenBinarySize = out_defs->kernel.elfSize;
+    out_defs->kernel.elfSize = IsaSize;
+
+    if (givenBinarySize >= IsaSize)
+    {
+        memcpy(out_defs->kernel.kernelElf,
+               &_binary___reinterpret_fwd_i32_o_start,
+               IsaSize);
+    }
+    else
+    {
+        return tpc_lib_api::GLUE_INSUFFICIENT_ELF_BUFFER;
+    }
+
+    return tpc_lib_api::GLUE_SUCCESS;
+}
diff --git a/src/gaudi2_src/reinterpret_fwd_i32.hpp b/src/gaudi2_src/reinterpret_fwd_i32.hpp
new file mode 100644
index 0000000..420be1f
--- /dev/null
+++ b/src/gaudi2_src/reinterpret_fwd_i32.hpp
@@ -0,0 +1,26 @@
+// reinterpret_fwd_i32.hpp
+#ifndef REINTERPRET_FWD_I32_HPP
+#define REINTERPRET_FWD_I32_HPP
+
+#include "gc_interface.h"
+#include "tpc_kernel_lib_interface.h"
+
+class ReinterpretFwdI32
+{
+public:
+    ReinterpretFwdI32() {}
+    virtual ~ReinterpretFwdI32() {}
+
+    virtual tpc_lib_api::GlueCodeReturn
+    GetGcDefinitions(tpc_lib_api::HabanaKernelParams* in_defs,
+                     tpc_lib_api::HabanaKernelInstantiation* out_defs);
+
+    virtual tpc_lib_api::GlueCodeReturn GetKernelName(
+        char kernelName [tpc_lib_api::MAX_NODE_NAME]);
+
+private:
+    ReinterpretFwdI32(const ReinterpretFwdI32& other) = delete;
+    ReinterpretFwdI32& operator=(const ReinterpretFwdI32& other) = delete;
+};
+
+#endif // REINTERPRET_FWD_I32_HPP
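-- 
A CPU-only reference for the expected gold bits (stock PyTorch; torch.Tensor.view
with a dtype argument reinterprets the underlying bytes, which is exactly what the
kernel is supposed to do on HPU), useful as an independent cross-check while
debugging this mismatch:

    import torch
    x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    print([f"0x{v:08x}" for v in x.view(torch.int32).tolist()])
    # ['0x3f800000', '0x40000000', '0x40400000']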