From 7d845f0d9853e9bdaff7f803486e01625230dd88 Mon Sep 17 00:00:00 2001 From: Gus Smith Date: Wed, 15 May 2019 13:34:30 -0700 Subject: [PATCH] [Datatypes] Custom datatypes (#2900) * Register and use custom datatypes in TVM This patch adds the ability to register and use a custom datatype from Python, using the `register_datatype` call. The datatype can then be passed as the `dtype` parameter using the syntax `dtype="custom[]bitsxlanes"`. * Removes extra file * Register custom datatypes with TVM; specify Cast and Add lowering This commit adds functionality for registering custom datatypes with TVM, and furthermore adding custom lowering functions to lower those custom datatypes. This commit only adds lowering for the Cast and Add ops; more ops will be added soon. Check out some custom datatype samples in my repository of samples: https://github.com/gussmith23/tvm-custom-datatype-samples * Register and lower casts from Python * Formatting * Fix include; was including too much * Add comment * Add DatatypeRegistered * Add storage size field to custom datatypes This field indicates the bitwidth of the opaque block of data into which instances of the datatype will be stored, when TVM compiles. For example, if I create a datatype with a storage size of 16, then - Constants of that datatype will be created as unsigned 16-bit ints - Calls to external functions taking that datatype will pass the data as unsigned 16-bit ints - External functions returning that datatype will be assumed to return unsigned 16-bit ints. * Change how lowering funcs (Cast and other ops) are named in registry tvm.datatypes.lower..cast.. becomes tvm.datatypes.lower..Cast.. And fixes some sloppy code around how the other ops were being formatted. * Update Python register_datatype to accept storage size * Oops, left out one cast->Cast change * Look up storage size when parsing `custom[typename]` When we encounter this type string in Python, it will be parsed into a Halide type object in C++. Some of my original code supported this parsing, but we now have to attach the storage type to the type (by setting the bits field). * Change how external calls for casting/other ops are done Firstly, we now use the storage size of the custom type when determining input/output types; e.g. a cast to a custom type with storage size 16 is seen as a call to an external function returning an opaque uint of size 16. Secondly, write a macro to handle the other ops. Originally I thought I could handle these at runtime, with a single `_register_op` global. I transitioned instead to using individual `_register_Add` etc. calls generated with a macro, but I don't remember why. * When encountering a custom type immediate, generate UIntImm * Translate custom types to LLVM type * Generate correct return type in Casts Originally I was assuming that the result type from casts was always a custom datatype, and so I was making the Call return a UInt type. * Use TVM-idiomatic recursion style in DatatypesLowerer This was actually a bug, I'm pretty sure; we wouldn't have recursed deep on any complex programs. As a result of making this change, I also uncovered another potential bug, where the datatypes lowering pass would attempt to lower a Load of a custom type. By commenting out the `Mutate_` for Load, I was able to stop the error from cropping up, but frankly, I'm not satisfied with the solution; how is it that we are able to run codegen when Loads of custom datatypes are present in the IR? I have not written any code, to my knowledge, that will support this. Perhaps Load does not care about the underlying datatype? * Use CHECK * Add comment about which Mutate_s are needed * Add comments * Add GetCustomDatatypeRegistered as an extern C function * Formatting, comments, casting * Change how datatype string is formatted * Use bits() instead of GetStorageSize Use bits() instead of GetStorageSize * Change comment * Add datatype.py * Change registered function name (datatypes->datatype) * Remove GetStorageSize * Format custom datatypes like any other datatype Specifically, we now print the bits and lanes after the `custom[...]` string. * Correctly implement datatype lowering in Python * Remove unneeded include * Make function naming consistent * Use CHECK instead of internal_assert * Rename macro * Formatting * Rename functions * Implement Cast lowering `_datatype_register_op` is now able to lower both binary ops and Casts. * Formatting * Formatting * Clang format, google style * Fix std::string/extern "C" warnings * Formatting * Formatting * Lower Allocates and Loads during datatype lowering This should ensure that there are no custom datatypes remaining once datatype lowering is done. This will allow us to remove the code in the LLVM codegen which deals with custom datatypes. * Revert additions to codegen_llvm.cc which are now unneeded * Pass cpplint on lower_datatypes.cc * Add clarifying comment * Remove datatype lowering registration funcs from C++ * Add CHECKs * Remove TODO * Remove all references to storage size * Move and rename function * Rename function * Remove done TODOs and other handled comments * Remove irrelevant Load code and comments * Comment out the IR node types I'm not sure about yet * Add bfloat16 datatype unittest * Fix MakeConstScalar MakeConstScalar for a custom datatype will now call out to a function which can be registered on a per-datatype basis. The function will take a double and return the equivalent value in the custom datatype format. Note that these code paths are not actually used or tested at the moment. I have not yet written an example which uses const scalars of a custom datatype. * Formatting * Change pass name * Allow users to register whatever lowering function they want Tianqi pointed out that users should be able to register whatever lowering function they want, and should not be constrained to registering lowering functions which just call out to external libraries. I still provide a function for making lowering functions which call out to external libraries, for convenience. * Add clarifying comment * Remove unneeded comment * Remove unneeded function * Rename file * Undo unnecessary change * Undo unnecessary change * Make naming consistent Rename "datatypes" to "custom datatypes" in most contexts. * Revert an artifact of old code * Fix build warnings, add TODO * Lint * Remove unnecessary use of extern C by separating decl and impl * Error checking * Remove TODO * Missed a name change * Lint * Python lint * Correctly format datatype * Move bfloat16 to 3rdparty * "custom_datatypes" --> "datatype" in most places I left the pass as "LowerCustomDatatypes" to indicate that we're not lowering anything other than custom datatypes. Otherwise, everything else has been changed. * Upgrade datatype unittest I used a float calculator to generate some real testcases for the unittest. * Separate public includes and private implementation Specifically, create cleaner decoupling between datatypes stuff in packed_func and the datatype registry implementation. * Formatting * Limit custom datatype codes to >128 * Add TODOs * Fix comment * Formatting * Clean up datatype unittest * Remove un-exported functions in public headers; UIntImm->FloatImm More places where I accidentally was using implementation-only functions in public headers. Additionally, store custom datatype immediates as FloatImms. A later change will add new lowering logic to lower these FloatImms to UIntImms. Plus formatting change. * Lint * Use FloatImm (not UIntImm) to hold immediates of custom datatypes This change switches from using UIntImm to FloatImm for storing immediates of custom datatypes. The value of the number is stored in a double, which should be enough precision for now, for most custom types we will explore in the immediate future. In line with this change, we change the datatype lowering so that FloatImms are lowered to UInts of the appropriate size. Originally, this was going to be done by allowing the user to register a double->uint__t conversion which would be called at compile time to convert the value from the FloatImm to a UInt and store it in a UIntImm. After discussions with Tianqi, we decided to take the simpler route, and lower FloatImms just as we lower all other ops: by replacing them with Call nodes. In this case, presumably the user will Call out to a conversion function in their datatype library. The justification for this decision is due to the functionality added in #1486. This pull request adds the ability to load LLVM bytecode in at compile time. This applies in our case as follows: 1. The user writes their custom datatype programs and registers their lowering functions in the same way we've been doing it so far. All operations over custom datatypes are lowered to Calls to the datatype library. 2. The user compiles their datatype library to LLVM bytecode. 3. At TVM compile time, the user loads the LLVM bytecode. Depending on how the datatype library is written, Clang should be able to perform constant folding over the custom datatype immediates, even if their conversions are done with calls to the library. Additionally adds test to test the FloatImm codepath. * Re-add a change I removed accidentally during rebase * Cleanup * Remove unnecessary TVM_DLLs * Add custom datatype utilities source file to Go runtime pack * Revert "Remove unnecessary TVM_DLLs" This reverts commit 4b742b99557fd3bf0ce6617f033c8b444b74eda4. * Mark bfloat code as TVM_DLL * Moves custom datatype runtime utilities to c_runtime_api.cc * Revert "Add custom datatype utilities source file to Go runtime pack" This reverts commit aecbcde0b2cc09a2693955b77037fe20f93b5bfd. * Move datatype parsing to its own function * Change comments * Remove unneeded function * Formatting * Formatting * Documentation * Add kCustomBegin, use it for checking for custom types * Documentation * Formatting * Move static definition to implementation * Remove comment * Decide toBeLowered before lowering arguments of Expr In the past, e.g. when lowering custom datatypes for an Add, we would lower a and b first, and then decide whether the resulting new Add needed to be lowered based on the (new) types of a and b. Now, instead, we need to check the types of a and b first (to see if they're custom types), and then lower them (so they'll become non-custom types), and then lower the new Add. * Revert "Move datatype parsing to its own function" This reverts commit d554a5881afcf69af1c070d882a7651022703a09. This broke parsing. Will figure this out later. There isn't a really clean way to separate this out given how the rest of the function is written. * Replace comment * Documentation * Remove comment and TVM_DLL * Better error messages * Remove artifact of rebase * Separate datatypes parsing to its own function * Add \returns * Comment changes; add TODO * Refactor tests --- 3rdparty/HalideIR | 2 +- 3rdparty/bfloat16/bfloat16.cc | 80 +++++++++ CMakeLists.txt | 4 + include/tvm/expr_operator.h | 7 + include/tvm/ir_pass.h | 11 ++ include/tvm/runtime/c_runtime_api.h | 2 + include/tvm/runtime/packed_func.h | 37 +++- python/tvm/__init__.py | 3 +- python/tvm/_ffi/runtime_ctypes.py | 14 +- python/tvm/datatype.py | 146 ++++++++++++++++ src/api/api_pass.cc | 1 + src/codegen/datatype/registry.cc | 108 ++++++++++++ src/codegen/datatype/registry.h | 162 ++++++++++++++++++ src/pass/lower_custom_datatypes.cc | 140 +++++++++++++++ src/runtime/c_runtime_api.cc | 46 +++++ .../test_custom_datatypes_mybfloat16.py | 150 ++++++++++++++++ 16 files changed, 908 insertions(+), 5 deletions(-) create mode 100644 3rdparty/bfloat16/bfloat16.cc create mode 100644 python/tvm/datatype.py create mode 100644 src/codegen/datatype/registry.cc create mode 100644 src/codegen/datatype/registry.h create mode 100644 src/pass/lower_custom_datatypes.cc create mode 100644 tests/python/unittest/test_custom_datatypes_mybfloat16.py diff --git a/3rdparty/HalideIR b/3rdparty/HalideIR index a768f2f06279..ec9585a5a5df 160000 --- a/3rdparty/HalideIR +++ b/3rdparty/HalideIR @@ -1 +1 @@ -Subproject commit a768f2f0627917659a4d7167eee3190469b9d164 +Subproject commit ec9585a5a5df3de91e8916ac2d27a4a509eac5fc diff --git a/3rdparty/bfloat16/bfloat16.cc b/3rdparty/bfloat16/bfloat16.cc new file mode 100644 index 000000000000..333b534afc08 --- /dev/null +++ b/3rdparty/bfloat16/bfloat16.cc @@ -0,0 +1,80 @@ +/* + Copyright (c) 2019 by Contributors + \file tvm/src/codegen/custom_datatypes/mybfloat16.cc + \brief Small bfloat16 library for use in unittests + + Code originally from Tensorflow; taken and simplified. Original license: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + ==============================================================================*/ + +#include +#include +#include + +void FloatToBFloat16(const float* src, uint16_t* dst, size_t size) { + const uint16_t* p = reinterpret_cast(src); + uint16_t* q = reinterpret_cast(dst); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (; size != 0; p += 2, q++, size--) { + *q = p[0]; + } +#else + for (; size != 0; p += 2, q++, size--) { + *q = p[1]; + } +#endif +} + +void BFloat16ToFloat(const uint16_t* src, float* dst, size_t size) { + const uint16_t* p = reinterpret_cast(src); + uint16_t* q = reinterpret_cast(dst); +#if __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__ + for (; size != 0; p++, q += 2, size--) { + q[0] = *p; + q[1] = 0; + } +#else + for (; size != 0; p++, q += 2, size--) { + q[0] = 0; + q[1] = *p; + } +#endif +} + +void BFloat16Add(const uint16_t* a, const uint16_t* b, uint16_t* dst, + size_t size) { + float a_f, b_f; + BFloat16ToFloat(a, &a_f, 1); + BFloat16ToFloat(b, &b_f, 1); + float out_f = a_f + b_f; + FloatToBFloat16(&out_f, dst, 1); +} + +extern "C" { +TVM_DLL TVM_DLL uint16_t FloatToBFloat16_wrapper(float in) { + uint16_t out; + FloatToBFloat16(&in, &out, 1); + return out; +} + +TVM_DLL float BFloat16ToFloat_wrapper(uint16_t in) { + float out; + BFloat16ToFloat(&in, &out, 1); + return out; +} + +TVM_DLL uint16_t BFloat16Add_wrapper(uint16_t a, uint16_t b) { + uint16_t out; + BFloat16Add(&a, &b, &out, 1); + return out; +} +} diff --git a/CMakeLists.txt b/CMakeLists.txt index d3604f09256b..e1e457d31b52 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -122,6 +122,8 @@ file(GLOB_RECURSE RELAY_SRCS ) list(APPEND COMPILER_SRCS ${RELAY_SRCS}) +file(GLOB DATATYPE_SRCS src/codegen/datatype/*.cc) +list(APPEND COMPILER_SRCS ${DATATYPE_SRCS}) if(NOT MSVC) file(GLOB COMPILER_VERILOG_SRCS src/codegen/verilog/*.cc) @@ -151,6 +153,8 @@ if(NOT USE_RTTI) add_definitions(-DDMLC_ENABLE_RTTI=0) endif() +list(APPEND RUNTIME_SRCS 3rdparty/bfloat16/bfloat16.cc) + if(USE_RPC) message(STATUS "Build with RPC support...") file(GLOB RUNTIME_RPC_SRCS src/runtime/rpc/*.cc) diff --git a/include/tvm/expr_operator.h b/include/tvm/expr_operator.h index 4ef3effaf251..2e1348e00470 100644 --- a/include/tvm/expr_operator.h +++ b/include/tvm/expr_operator.h @@ -33,6 +33,7 @@ #include "ir.h" namespace tvm { + /*! * \brief Make a const value with certain data type. * \param t The target type. @@ -551,6 +552,12 @@ inline Expr MakeConstScalar(Type t, ValueType value) { if (t.is_int()) return ir::IntImm::make(t, static_cast(value)); if (t.is_uint()) return ir::UIntImm::make(t, static_cast(value)); if (t.is_float()) return ir::FloatImm::make(t, static_cast(value)); + // For now, we store const scalar values of custom datatypes within doubles; later, during the + // datatypes lowering pass, we will lower the value to its true representation in the format + // specified by the datatype. + // TODO(gus) when do we need to start worrying about doubles not being precise enough? + if (static_cast(t.code()) >= static_cast(kCustomBegin)) + return ir::FloatImm::make(t, static_cast(value)); LOG(FATAL) << "cannot make const for type " << t; return Expr(); } diff --git a/include/tvm/ir_pass.h b/include/tvm/ir_pass.h index 20b56e0676eb..5ef4dc4ed9d7 100644 --- a/include/tvm/ir_pass.h +++ b/include/tvm/ir_pass.h @@ -500,6 +500,17 @@ LoweredFunc PointerValueTypeRewrite(LoweredFunc f); */ LoweredFunc LowerIntrin(LoweredFunc f, const std::string& target); +/*! + * \brief Lower custom datatypes. + * + * See tvm::datatypes::Registry for more information on adding custom datatypes. + * + * \param f The device function to be lowered. + * \param target The target device. + * \return Transformed function. + */ +LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target); + /*! * \brief Verify if memory accesses are legal for a specific target device type. * diff --git a/include/tvm/runtime/c_runtime_api.h b/include/tvm/runtime/c_runtime_api.h index f992e87ad100..ee3542f90255 100644 --- a/include/tvm/runtime/c_runtime_api.h +++ b/include/tvm/runtime/c_runtime_api.h @@ -114,6 +114,8 @@ typedef enum { // The following section of code is used for non-reserved types. kExtReserveEnd = 64U, kExtEnd = 128U, + // The rest of the space is used for custom, user-supplied datatypes + kCustomBegin = 128U, } TVMTypeCode; /*! diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index 9fcefcbbe4b1..82b3dd469541 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -60,6 +60,29 @@ namespace tvm { class Integer; namespace runtime { + +/*! + * \brief Runtime utility for getting custom type name from code + * \param type_code Custom type code + * \return Custom type name + */ +TVM_DLL std::string GetCustomTypeName(uint8_t type_code); + +/*! + * \brief Runtime utility for checking whether custom type is registered + * \param type_code Custom type code + * \return Bool representing whether type is registered + */ +TVM_DLL bool GetCustomTypeRegistered(uint8_t type_code); + +/*! + * \brief Runtime utility for parsing string of the form "custom[]" + * \param s String to parse + * \param scan pointer to parsing pointer, which is scanning across s + * \return type code of custom type parsed + */ +TVM_DLL uint8_t ParseCustomDatatype(const std::string& s, const char** scan); + // forward declarations class TVMArgs; class TVMArgValue; @@ -939,7 +962,11 @@ inline std::ostream& operator<<(std::ostream& os, TVMType t) { // NOLINT(*) if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { os << "bool"; return os; } - os << TypeCode2Str(t.code); + if (GetCustomTypeRegistered(t.code)) { + os << "custom[" << GetCustomTypeName(t.code) << "]"; + } else { + os << TypeCode2Str(t.code); + } if (t.code == kHandle) return os; os << static_cast(t.bits); if (t.lanes != 1) { @@ -960,7 +987,11 @@ inline std::string TVMType2String(TVMType t) { if (t.bits == 1 && t.lanes == 1 && t.code == kDLUInt) { return "bool"; } - repr += TypeCode2Str(t.code); + if (GetCustomTypeRegistered(t.code)) { + repr += "custom[" + GetCustomTypeName(t.code) + "]"; + } else { + repr += TypeCode2Str(t.code); + } if (t.code == kHandle) return repr; repr += std::to_string(static_cast(t.bits)); if (t.lanes != 1) { @@ -994,6 +1025,8 @@ inline TVMType String2TVMType(std::string s) { t.bits = 1; t.lanes = 1; return t; + } else if (s.substr(0, 6) == "custom") { + t.code = ParseCustomDatatype(s, &scan); } else { scan = s.c_str(); LOG(FATAL) << "unknown type " << s; diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index ce6f0602a572..5765eed0ad8b 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -38,12 +38,13 @@ from . import hybrid from . import testing from . import error +from . import datatype from . import ndarray as nd from .ndarray import context, cpu, gpu, opencl, cl, vulkan, metal, mtl from .ndarray import vpi, rocm, opengl, ext_dev -from ._ffi.runtime_ctypes import TypeCode +from ._ffi.runtime_ctypes import TypeCode, TVMType from ._ffi.ndarray import TVMContext from ._ffi.function import Function from ._ffi.base import TVMError, __version__ diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py index 4ede33a63936..72cff1a10ead 100644 --- a/python/tvm/_ffi/runtime_ctypes.py +++ b/python/tvm/_ffi/runtime_ctypes.py @@ -91,6 +91,13 @@ def __init__(self, type_str): self.type_code = 4 bits = 64 head = "" + elif head.startswith("custom"): + low, high = head.find('['), head.find(']') + if not low or not high or low >= high: + raise ValueError("Badly formatted custom type string %s" % type_str) + type_name = head[low + 1:high] + self.type_code = _api_internal._datatype_get_type_code(type_name) + head = head[high+1:] else: raise ValueError("Do not know how to handle type %s" % type_str) bits = int(head) if head else bits @@ -100,7 +107,12 @@ def __init__(self, type_str): def __repr__(self): if self.bits == 1 and self.lanes == 1: return "bool" - x = "%s%d" % (TVMType.CODE2STR[self.type_code], self.bits) + if self.type_code in TVMType.CODE2STR: + type_name = TVMType.CODE2STR[self.type_code] + else: + type_name = "custom[%s]" % \ + _api_internal._datatype_get_type_name(self.type_code) + x = "%s%d" % (type_name, self.bits) if self.lanes != 1: x += "x%d" % self.lanes return x diff --git a/python/tvm/datatype.py b/python/tvm/datatype.py new file mode 100644 index 000000000000..df3e3a62a510 --- /dev/null +++ b/python/tvm/datatype.py @@ -0,0 +1,146 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Custom datatype functionality""" +from __future__ import absolute_import as _abs + +from ._ffi.function import register_func as _register_func +from . import make as _make +from .api import convert +from .expr import Call as _Call, Cast as _Cast, FloatImm as _FloatImm +from ._ffi.runtime_ctypes import TVMType as _TVMType +from . import _api_internal + + +def register(type_name, type_code): + """Register a custom datatype with the given type name and type code + Currently, the type code is manually allocated by the user, and the + user must ensure that no two custom types share the same code. + Generally, this should be straightforward, as the user will be + manually registering all of their custom types. + + Parameters + ---------- + type_name : str + The name of the custom datatype + + type_code : int + The type's code, which should be >= kCustomBegin + """ + _api_internal._datatype_register(type_name, type_code) + + +def get_type_name(type_code): + """Get the type name from the type code + + Parameters + ---------- + type_code : int + The type code + """ + return _api_internal._datatype_get_type_name(type_code) + + +def get_type_code(type_name): + """Get the type code from the type name + + Parameters + ---------- + type_name : str + The type name + """ + return _api_internal._datatype_get_type_code(type_name) + + +def get_type_registered(type_code): + """Get a boolean representing whether the type is registered + + Parameters + ---------- + type_code: int + The type code + """ + return _api_internal._datatype_get_type_registered(type_code) + + +def register_op(lower_func, op_name, target, type_name, src_type_name=None): + """Register an external function which computes the given op. + + Currently, this will only work with Casts and binary expressions + whose arguments are named `a` and `b`. + TODO(gus) figure out what other special cases must be handled by + looking through expr.py. + + Parameters + ---------- + lower_func : function + The lowering function to call. See create_lower_func. + + op_name : str + The name of the operation which the function computes, given by its + Halide::Internal class name (e.g. Add, LE, Cast). + + target : str + The name of codegen target. + + type_name : str + The name of the custom datatype, e.g. posit (but not custom[posit]8). + + src_type_name : str + If op_name is "Cast", then this should be set to the source datatype of + the argument to the Cast. If op_name is not "Cast", this is unused. + """ + + if op_name == "Cast": + assert src_type_name is not None + lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \ + + type_name + "." + src_type_name + else: + lower_func_name = "tvm.datatype.lower." + target + "." + op_name + "." \ + + type_name + _register_func(lower_func_name, lower_func) + + +def create_lower_func(extern_func_name): + """Returns a function which lowers an operation to a function call. + + Parameters + ---------- + extern_func_name : str + The name of the extern "C" function to lower to + """ + + def lower(op): + """ + Takes an op---either a Cast or a binary op (e.g. an Add) and returns a + call to the specified external function, passing the op's argument + (Cast) or arguments (a binary op). The return type of the call depends + on the type of the op: if it is a custom type, then a uint of the same + width as the custom type is returned. Otherwise, the type is + unchanged.""" + dtype = op.dtype + t = _TVMType(dtype) + if get_type_registered(t.type_code): + dtype = "uint" + str(t.bits) + if t.lanes > 1: + dtype += "x" + str(t.lanes) + if isinstance(op, (_Cast, _FloatImm)): + return _make.Call(dtype, extern_func_name, convert([op.value]), + _Call.Extern, None, 0) + return _make.Call(dtype, extern_func_name, convert([op.a, op.b]), + _Call.Extern, None, 0) + + return lower diff --git a/src/api/api_pass.cc b/src/api/api_pass.cc index 6195aac1b93f..d6c92aee94d1 100644 --- a/src/api/api_pass.cc +++ b/src/api/api_pass.cc @@ -151,6 +151,7 @@ REGISTER_PASS(LowerThreadAllreduce); REGISTER_PASS(LowerWarpMemory); REGISTER_PASS(RemapThreadAxis); REGISTER_PASS(LowerIntrin); +REGISTER_PASS(LowerCustomDatatypes); REGISTER_PASS(LowerTVMBuiltin); REGISTER_PASS(CombineContextCall); REGISTER_PASS(VerifyMemory); diff --git a/src/codegen/datatype/registry.cc b/src/codegen/datatype/registry.cc new file mode 100644 index 000000000000..28cc58204e8d --- /dev/null +++ b/src/codegen/datatype/registry.cc @@ -0,0 +1,108 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "registry.h" +#include + +namespace tvm { +namespace datatype { + +TVM_REGISTER_GLOBAL("_datatype_register").set_body([](TVMArgs args, TVMRetValue* ret) { + datatype::Registry::Global()->Register(args[0], static_cast(args[1].operator int())); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_code").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = datatype::Registry::Global()->GetTypeCode(args[0]); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_name").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeName(args[0].operator int()); +}); + +TVM_REGISTER_GLOBAL("_datatype_get_type_registered").set_body([](TVMArgs args, TVMRetValue* ret) { + *ret = Registry::Global()->GetTypeRegistered(args[0].operator int()); +}); + +Registry* Registry::Global() { + static Registry inst; + return &inst; +} + +void Registry::Register(const std::string& type_name, uint8_t type_code) { + CHECK(type_code >= kCustomBegin) << "Please choose a type code >= kCustomBegin for custom types"; + code_to_name_[type_code] = type_name; + name_to_code_[type_name] = type_code; +} + +uint8_t Registry::GetTypeCode(const std::string& type_name) { + CHECK(name_to_code_.find(type_name) != name_to_code_.end()) + << "Type name " << type_name << " not registered"; + return name_to_code_[type_name]; +} + +std::string Registry::GetTypeName(uint8_t type_code) { + CHECK(code_to_name_.find(type_code) != code_to_name_.end()) + << "Type code " << static_cast(type_code) << " not registered"; + return code_to_name_[type_code]; +} + +const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code, + uint8_t src_type_code) { + std::ostringstream ss; + ss << "tvm.datatype.lower."; + ss << target << "."; + ss << "Cast" + << "."; + + if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { + ss << datatype::Registry::Global()->GetTypeName(type_code); + } else { + ss << runtime::TypeCode2Str(type_code); + } + + ss << "."; + + if (datatype::Registry::Global()->GetTypeRegistered(src_type_code)) { + ss << datatype::Registry::Global()->GetTypeName(src_type_code); + } else { + ss << runtime::TypeCode2Str(src_type_code); + } + + return runtime::Registry::Get(ss.str()); +} + +const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code) { + std::ostringstream ss; + ss << "tvm.datatype.lower."; + ss << target; + ss << ".FloatImm."; + ss << datatype::Registry::Global()->GetTypeName(type_code); + return runtime::Registry::Get(ss.str()); +} + +uint64_t ConvertConstScalar(uint8_t type_code, double value) { + std::ostringstream ss; + ss << "tvm.datatype.convertconstscalar.float."; + ss << datatype::Registry::Global()->GetTypeName(type_code); + auto make_const_scalar_func = runtime::Registry::Get(ss.str()); + return (*make_const_scalar_func)(value).operator uint64_t(); +} + +} // namespace datatype +} // namespace tvm diff --git a/src/codegen/datatype/registry.h b/src/codegen/datatype/registry.h new file mode 100644 index 000000000000..d2e615765a18 --- /dev/null +++ b/src/codegen/datatype/registry.h @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_CODEGEN_DATATYPE_REGISTRY_H_ +#define TVM_CODEGEN_DATATYPE_REGISTRY_H_ + +#include +#include +#include +#include + +namespace tvm { +namespace datatype { + +/*! + * \brief Registry for custom datatypes. + * + * Adding custom datatypes currently requires two steps: + * 1. Register the datatype with the registry via a call to + * datatype::Registry::Register. This can also be done in Python + * directly---see the TVM globals registered in the corresponding .cc file. + * Currently, user should manually choose a type name and a type code, + * ensuring that neither conflict with existing types. + * 2. Use TVM_REGISTER_GLOBAL to register the lowering functions needed to + * lower the custom datatype. In general, these will look like: + * For Casts: tvm.datatype.lower..Cast.. + * Example: tvm.datatype.lower.llvm.Cast.myfloat.float for a Cast from + * float to myfloat. + * For other ops: tvm.datatype.lower... + * Examples: tvm.datatype.lower.llvm.Add.myfloat + * tvm.datatype.lower.llvm.FloatImm.posit + */ +class Registry { + public: + /*! + * \brief Get the global custom datatype registry singleton + */ + static Registry* Global(); + + /*! + * \brief Register custom datatype + * Register a custom datatype with the given type name and type code. Currently, the type code is + * manually allocated by the user, and the user must ensure that no two custom types share the + * same code. Generally, this should be straightforward, as the user will be manually registering + * all of their custom types. + * \param type_name The name of the type, e.g. "bfloat" + * \param type_code The type code, which should be greater than TVMTypeCode::kExtEnd + */ + void Register(const std::string& type_name, uint8_t type_code); + + /*! + * \brief Get type code from type name + * \param type_name The type name + * \return The type code + */ + uint8_t GetTypeCode(const std::string &type_name); + + /*! + * \brief Get type name from type code + * \param type_code The type code + * \return The type name + */ + std::string GetTypeName(uint8_t type_code); + + /*! + * \brief Get bool representing whether type is registered, given the type code + * \param type_code The type code + * \return bool representing whether the type is registered + */ + inline bool GetTypeRegistered(uint8_t type_code) { + return code_to_name_.find(type_code) != code_to_name_.end(); + } + + /*! + * \brief Get bool representing whether type is registered, given the type name + * \param type_name The type name + * \return bool representing whether the type is registered + */ + inline bool GetTypeRegistered(std::string type_name) { + return name_to_code_.find(type_name) != name_to_code_.end(); + } + + private: + // TODO(gus) is there a typedef for the code? + std::unordered_map code_to_name_; + std::unordered_map name_to_code_; +}; + +/*! + * \brief Convert scalar value to a custom datatype format + * \param type_code The custom datatype to convert to, specified by type code + * \param value The floating point value to convert + * \return The value, encoded in the bits of a uint64_t + */ +uint64_t ConvertConstScalar(uint8_t type_code, double value); + +/*! + * \brief Get lowering function for Cast ops + * \param target The target we are lowering to, e.g. "llvm" + * \param type_code The datatype being cast to + * \param src_type_code The datatype being cast from + * \return Lowering function for Cast ops for the provided target, type, and source type + */ +const runtime::PackedFunc* GetCastLowerFunc(const std::string& target, uint8_t type_code, + uint8_t src_type_code); + +/*! + * \brief Get lowering function for FloatImms + * \param target The target we are lowering to, e.g. "llvm" + * \param type_code The datatype of the FloatImm + * \return Lowering function for FloatImms for the provided target and type + */ +const runtime::PackedFunc* GetFloatImmLowerFunc(const std::string& target, uint8_t type_code); + +/*! + * \brief Get lowering function for other ops + * \param target The target we are lowering to, e.g. "llvm" + * \param type_code The datatype of the op + * \return Lowering function for other ops for the provided target and type + */ +#define DEFINE_GET_LOWER_FUNC_(OP) \ + inline const runtime::PackedFunc* Get##OP##LowerFunc(const std::string& target, \ + uint8_t type_code) { \ + return runtime::Registry::Get("tvm.datatype.lower." + target + "." #OP "." + \ + datatype::Registry::Global()->GetTypeName(type_code)); \ + } + +DEFINE_GET_LOWER_FUNC_(Add) +DEFINE_GET_LOWER_FUNC_(Sub) +DEFINE_GET_LOWER_FUNC_(Mul) +DEFINE_GET_LOWER_FUNC_(Div) +DEFINE_GET_LOWER_FUNC_(Mod) +DEFINE_GET_LOWER_FUNC_(Min) +DEFINE_GET_LOWER_FUNC_(Max) +DEFINE_GET_LOWER_FUNC_(EQ) +DEFINE_GET_LOWER_FUNC_(NE) +DEFINE_GET_LOWER_FUNC_(LT) +DEFINE_GET_LOWER_FUNC_(LE) +DEFINE_GET_LOWER_FUNC_(GT) +DEFINE_GET_LOWER_FUNC_(GE) +// Later changes may need to add more lowering functions as we support workloads with more ops. + +} // namespace datatype +} // namespace tvm + +#endif // TVM_CODEGEN_DATATYPE_REGISTRY_H_ diff --git a/src/pass/lower_custom_datatypes.cc b/src/pass/lower_custom_datatypes.cc new file mode 100644 index 000000000000..7598ef49eee0 --- /dev/null +++ b/src/pass/lower_custom_datatypes.cc @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * Copyright (c) 2019 by Contributors + * \file tvm/src/pass/lower_custom_datatypes.cc + * \brief Pass for lowering custom datatypes + */ + +#include +#include +#include +#include "../codegen/datatype/registry.h" + +namespace tvm { +namespace ir { + +/*! + * \brief Helper mutator to implement lowering of custom datatypes. + * + * Lowering datatypes works as follows: for every expression containing a custom + * datatype, we search for a global (registered by the implementer of the custom + * datatype) for lowering this type of expression, and uses it to lower the + * expression. + */ +class CustomDatatypesLowerer : public IRMutator { + public: + explicit CustomDatatypesLowerer(const std::string& target) : target_(target) {} + + inline Expr Mutate_(const Cast* op, const Expr& e) final { + auto type_code = op->type.code(); + auto src_type_code = op->value.type().code(); + // If either datatype is a registered custom datatype, we must lower. + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code) || + datatype::Registry::Global()->GetTypeRegistered(src_type_code); + Expr expr = IRMutator::Mutate_(op, e); + op = expr.as(); + if (toBeLowered) { + auto lower = datatype::GetCastLowerFunc(target_, type_code, src_type_code); + CHECK(lower) << "Cast lowering function for target " << target_ << " destination type " + << static_cast(type_code) << " source type " + << static_cast(src_type_code) << " not found"; + return (*lower)(expr); + } + return expr; + } + + inline Expr Mutate_(const FloatImm* imm, const Expr& e) final { + auto type_code = imm->type.code(); + if (datatype::Registry::Global()->GetTypeRegistered(type_code)) { + auto lower = datatype::GetFloatImmLowerFunc(target_, type_code); + CHECK(lower) << "FloatImm lowering function for target " << target_ << " type " + << static_cast(type_code) << " not found"; + return (*lower)(e); + } + return e; + } + + inline Stmt Mutate_(const Allocate* allocate, const Stmt& s) final { + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(allocate->type.code()); + Stmt stmt = IRMutator::Mutate_(allocate, s); + allocate = stmt.as(); + + if (toBeLowered) { + auto new_allocate_type = UInt(allocate->type.bits(), allocate->type.lanes()); + return Allocate::make(allocate->buffer_var, new_allocate_type, allocate->extents, + allocate->condition, allocate->body, allocate->new_expr, + allocate->free_function); + } + return stmt; + } + + inline Expr Mutate_(const Load* load, const Expr& e) final { + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(load->type.code()); + Expr expr = IRMutator::Mutate_(load, e); + load = expr.as(); + if (toBeLowered) { + auto new_load_type = UInt(load->type.bits()); + return Load::make(new_load_type, load->buffer_var, load->index, load->predicate); + } + return expr; + } + +#define DEFINE_MUTATE__(OP) \ + inline Expr Mutate_(const OP* op, const Expr& e) final { \ + auto type_code = op->type.code(); \ + bool toBeLowered = datatype::Registry::Global()->GetTypeRegistered(type_code); \ + Expr expr = IRMutator::Mutate_(op, e); \ + op = expr.as(); \ + if (toBeLowered) { \ + auto lower = datatype::Get##OP##LowerFunc(target_, type_code); \ + CHECK(lower) << #OP " lowering function for target " << target_ << " type " \ + << static_cast(type_code) << " not found"; \ + return (*lower)(expr); \ + } \ + return expr; \ + } + + DEFINE_MUTATE__(Add) + DEFINE_MUTATE__(Sub) + DEFINE_MUTATE__(Mul) + DEFINE_MUTATE__(Div) + DEFINE_MUTATE__(Mod) + DEFINE_MUTATE__(Min) + DEFINE_MUTATE__(Max) + DEFINE_MUTATE__(EQ) + DEFINE_MUTATE__(NE) + DEFINE_MUTATE__(LT) + DEFINE_MUTATE__(LE) + DEFINE_MUTATE__(GT) + DEFINE_MUTATE__(GE) + // Later changes may need to add more mutate functions as we support workloads with more ops. + + private: + std::string target_; +}; + +LoweredFunc LowerCustomDatatypes(LoweredFunc f, const std::string& target) { + auto n = make_node(*f.operator->()); + n->body = CustomDatatypesLowerer(target).Mutate(n->body); + return LoweredFunc(n); +} + +} // namespace ir +} // namespace tvm diff --git a/src/runtime/c_runtime_api.cc b/src/runtime/c_runtime_api.cc index 59cdb7f0a467..20793b4618b3 100644 --- a/src/runtime/c_runtime_api.cc +++ b/src/runtime/c_runtime_api.cc @@ -45,6 +45,52 @@ namespace tvm { namespace runtime { +std::string GetCustomTypeName(uint8_t type_code) { + auto f = tvm::runtime::Registry::Get("_datatype_get_type_name"); + CHECK(f) << "Function _datatype_get_type_name not found"; + return (*f)(type_code).operator std::string(); +} + +uint8_t GetCustomTypeCode(const std::string& type_name) { + auto f = tvm::runtime::Registry::Get("_datatype_get_type_code"); + CHECK(f) << "Function _datatype_get_type_code not found"; + return (*f)(type_name).operator int(); +} + +bool GetCustomTypeRegistered(uint8_t type_code) { + auto f = tvm::runtime::Registry::Get("_datatype_get_type_registered"); + CHECK(f) << "Function _datatype_get_type_registered not found"; + return (*f)(type_code).operator bool(); +} + +uint8_t ParseCustomDatatype(const std::string& s, const char** scan) { + CHECK(s.substr(0, 6) == "custom") << "Not a valid custom datatype string"; + + auto tmp = s.c_str(); + + CHECK(s.c_str() == tmp); + *scan = s.c_str() + 6; + CHECK(s.c_str() == tmp); + if (**scan != '[') LOG(FATAL) << "expected opening brace after 'custom' type in" << s; + CHECK(s.c_str() == tmp); + *scan += 1; + CHECK(s.c_str() == tmp); + size_t custom_name_len = 0; + CHECK(s.c_str() == tmp); + while (*scan + custom_name_len <= s.c_str() + s.length() && *(*scan + custom_name_len) != ']') + ++custom_name_len; + CHECK(s.c_str() == tmp); + if (*(*scan + custom_name_len) != ']') + LOG(FATAL) << "expected closing brace after 'custom' type in" << s; + CHECK(s.c_str() == tmp); + *scan += custom_name_len + 1; + CHECK(s.c_str() == tmp); + + auto type_name = s.substr(7, custom_name_len); + CHECK(s.c_str() == tmp); + return GetCustomTypeCode(type_name); +} + class DeviceAPIManager { public: static const int kMaxDeviceAPI = 32; diff --git a/tests/python/unittest/test_custom_datatypes_mybfloat16.py b/tests/python/unittest/test_custom_datatypes_mybfloat16.py new file mode 100644 index 000000000000..99c6cf5f268b --- /dev/null +++ b/tests/python/unittest/test_custom_datatypes_mybfloat16.py @@ -0,0 +1,150 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import tvm +from ctypes import * +import topi +import tvm.ir_pass as ir_pass +import numpy as np + +tgt = "llvm" + + +def setup(): + # You must first load the library containing the datatype implementation. + # In this case, we have built the test functions used below right into TVM. + # CDLL("libmybfloat16.so", RTLD_GLOBAL) + + tvm.datatype.register("bfloat", 129) + + tvm.datatype.register_op( + tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "Cast", + "llvm", "bfloat", "float") + tvm.datatype.register_op( + tvm.datatype.create_lower_func("BFloat16ToFloat_wrapper"), "Cast", + "llvm", "float", "bfloat") + tvm.datatype.register_op( + tvm.datatype.create_lower_func("BFloat16Add_wrapper"), "Add", "llvm", + "bfloat") + tvm.datatype.register_op( + tvm.datatype.create_lower_func("FloatToBFloat16_wrapper"), "FloatImm", + "llvm", "bfloat") + +def lower_datatypes_and_build(schedule, args): + """Create schedule and lower, manually lowering datatypes. + + Once datatype lowering is integrated directly into TVM's lower/build + process, we won't need to do this manually. + TODO(gus) integrate datatype lowering into build process; change this test""" + flist = tvm.lower(schedule, args) + flist = [flist] + flist = [ir_pass.LowerCustomDatatypes(func, tgt) for func in flist] + return tvm.build(flist[0], target=tgt) + +def test_bfloat_add_and_cast_1(): + X = tvm.placeholder((3, ), name="X") + Y = tvm.placeholder((3, ), name="Y") + Z = topi.cast( + topi.cast(X, dtype="custom[bfloat]16") + + topi.cast(Y, dtype="custom[bfloat]16"), + dtype="float") + + s = tvm.create_schedule([Z.op]) + built_cast = lower_datatypes_and_build(s, [X,Y,Z]) + + ctx = tvm.context(tgt, 0) + + # Used float32 calculator at http://www.weitz.de/ieee/. Generated float32s + # with at most 7-bit mantissas which, when added, produce a result with at + # most 7-bit mantissas. This is to ensure there are no errors due to + # float32->bfloat16 conversions. + x = tvm.nd.array( + np.array([4.4103796E-32, 14942208.0, 1.78125]).astype("float32"), + ctx=ctx) + y = tvm.nd.array( + np.array([-3.330669E-14, 19660800.0, 2.25]).astype("float32"), ctx=ctx) + z_expected = np.array([-3.330669E-14, 34603008.0, + 4.03125]).astype("float32") + z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx) + + built_cast(x, y, z) + + assert np.array_equal(z_expected, z.asnumpy()) + + +def test_bfloat_add_and_cast_2(): + X = tvm.placeholder((3, ), name="X") + Y = tvm.placeholder((3, ), name="Y") + Z = topi.cast( + topi.cast(X, dtype="custom[bfloat]16") + + topi.cast(Y, dtype="custom[bfloat]16"), + dtype="float") + + s = tvm.create_schedule([Z.op]) + built_cast = lower_datatypes_and_build(s, [X,Y,Z]) + + ctx = tvm.context(tgt, 0) + + # Used float32 calculator at http://www.weitz.de/ieee/. Generated + # unconstrained float32s for the operands and copied them in to x and y. + # Then, to simulate float32->bfloat16 conversion implemented by the mybfloat + # library, I cut off all but 7 bits of the mantissa. I then added the + # numbers. To simulate bfloat16 add implemented in mybfloat, I cut off all + # but 7 bits of the result's mantissa. I then copied that value into + # z_expected. + x = tvm.nd.array( + np.array([1.2348297, -1.0298302E25, 1.2034023E-30]).astype("float32"), + ctx=ctx) + y = tvm.nd.array( + np.array([-2.4992788, -9.888288E19, 9.342338E-29]).astype("float32"), + ctx=ctx) + z_expected = np.array([-1.25, -1.027587E25, + 9.426888E-29]).astype("float32") + z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx) + + built_cast(x, y, z) + + assert np.array_equal(z_expected, z.asnumpy()) + + +def test_bfloat_add_and_cast_FloatImm(): + X = tvm.placeholder((3, ), name="X") + Z = topi.cast( + topi.add( + topi.cast(X, dtype="custom[bfloat]16"), + tvm.expr.FloatImm("custom[bfloat]16", 1.5)), + dtype="float") + + s = tvm.create_schedule([Z.op]) + built_cast = lower_datatypes_and_build(s, [X,Z]) + + ctx = tvm.context(tgt, 0) + + x = tvm.nd.array(np.array([0.0, 1.0, 1.5]).astype("float32"), ctx=ctx) + z_expected = np.array([1.5, 2.5, 3.0]).astype("float32") + z = tvm.nd.empty(Z.shape, dtype=Z.dtype, ctx=ctx) + + built_cast(x, z) + + assert np.array_equal(z_expected, z.asnumpy()) + + +if __name__ == "__main__": + setup() + test_bfloat_add_and_cast_1() + test_bfloat_add_and_cast_2() + test_bfloat_add_and_cast_FloatImm()