diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 0fc832e0fb7ae..7022b948781e6 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -85,6 +85,8 @@ namespace tvm { */ template inline TObjectRef NullValue() { + static_assert(TObjectRef::_type_is_nullable, + "Can only value for nullable types"); return TObjectRef(ObjectPtr(nullptr)); } diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4e0a301156a34..9d77136e14ff6 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -317,6 +317,47 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); }; +/*! + * \brief Boolean constant. + * + * This reference type is useful to add additional compile-time + * type checks and helper functions for Integer equal comparisons. + */ +class Bool : public IntImm { + public: + explicit Bool(bool value) + : IntImm(DataType::Bool(), value) { + } + Bool operator!() const { + return Bool((*this)->value == 0); + } + operator bool() const { + return (*this)->value != 0; + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); +}; + +// Overload operators to make sure we have the most fine grained types. +inline Bool operator||(const Bool& a, bool b) { + return Bool(a.operator bool() || b); +} +inline Bool operator||(bool a, const Bool& b) { + return Bool(a || b.operator bool()); +} +inline Bool operator||(const Bool& a, const Bool& b) { + return Bool(a.operator bool() || b.operator bool()); +} +inline Bool operator&&(const Bool& a, bool b) { + return Bool(a.operator bool() && b); +} +inline Bool operator&&(bool a, const Bool& b) { + return Bool(a && b.operator bool()); +} +inline Bool operator&&(const Bool& a, const Bool& b) { + return Bool(a.operator bool() && b.operator bool()); +} + /*! * \brief Container of constant int that adds more constructors. * @@ -346,10 +387,10 @@ class Integer : public IntImm { * \tparam Enum The enum type. * \param value The enum value. */ - template::value>::type> - explicit Integer(ENum value) : Integer(static_cast(value)) { - static_assert(std::is_same::type>::value, + template::value>::type> + explicit Integer(Enum value) : Integer(static_cast(value)) { + static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); } /*! @@ -368,6 +409,24 @@ class Integer : public IntImm { << " Trying to reference a null Integer"; return (*this)->value; } + // comparators + Bool operator==(int other) const { + if (data_ == nullptr) return Bool(false); + return Bool((*this)->value == other); + } + Bool operator!=(int other) const { + return !(*this == other); + } + template::value>::type> + Bool operator==(Enum other) const { + return *this == static_cast(other); + } + template::value>::type> + Bool operator!=(Enum other) const { + return *this != static_cast(other); + } }; /*! \brief range over one dimension */ diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index dc7a2b2185687..4050fd8ee72e5 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -90,25 +91,31 @@ class BaseFuncNode : public RelayExprNode { * \code * * void GetAttrExample(const BaseFunc& f) { - * Integer value = f->GetAttr("AttrKey", 0); + * auto value = f->GetAttr("AttrKey", 0); * } * * \endcode */ template - TObjectRef GetAttr(const std::string& attr_key, - TObjectRef default_value = NullValue()) const { + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { static_assert(std::is_base_of::value, "Can only call GetAttr with ObjectRef types."); if (!attrs.defined()) return default_value; auto it = attrs->dict.find(attr_key); if (it != attrs->dict.end()) { - return Downcast((*it).second); + return Downcast>((*it).second); } else { return default_value; } } - + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr( + const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } /*! * \brief Check whether the function has an non-zero integer attr. * @@ -129,7 +136,7 @@ class BaseFuncNode : public RelayExprNode { * \endcode */ bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0)->value != 0; + return GetAttr(attr_key) != 0; } static constexpr const char* _type_key = "BaseFunc"; diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index b39e3b4034213..471a0de361b7c 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -63,7 +63,6 @@ using runtime::make_object; using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::String; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 8963f0921276e..336b3589d39c4 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -353,6 +353,10 @@ class StringObj : public Object { */ class String : public ObjectRef { public: + /*! + * \brief Construct an empty string. + */ + String() : String(std::string()) {} /*! * \brief Construct a new String object * @@ -467,9 +471,6 @@ class String : public ObjectRef { */ size_t size() const { const auto* ptr = get(); - if (ptr == nullptr) { - return 0; - } return ptr->size; } @@ -524,7 +525,7 @@ class String : public ObjectRef { /*! \return the internal StringObj pointer */ const StringObj* get() const { return operator->(); } - TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); private: /*! @@ -610,7 +611,129 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { } }; +/*! + * \brief Optional container that to represent to a Nullable variant of T. + * \tparam T The original ObjectRef. + * + * \code + * + * Optional opt0 = nullptr; + * Optional opt1 = String("xyz"); + * CHECK(opt0 == nullptr); + * CHECK(opt1 == "xyz"); + * + * \endcode + */ +template +class Optional : public ObjectRef { + public: + using ContainerType = typename T::ContainerType; + static_assert(std::is_base_of::value, + "Optional is only defined for ObjectRef."); + // default constructors. + Optional() = default; + Optional(const Optional&) = default; + Optional(Optional&&) = default; + Optional& operator=(const Optional&) = default; + Optional& operator=(Optional&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already matches the ContainerType. + * \param ptr + */ + explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit Optional(std::nullptr_t) {} + Optional& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) { + } + Optional& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // delete the int constructor + // since Optional(0) is ambiguious + // 0 can be implicitly casted to nullptr_t + Optional(int val) = delete; + Optional& operator=(int val) = delete; + /*! + * \return A not-null container value in the optional. + * \note This function performs not-null checking. + */ + T value() const { + CHECK(data_ != nullptr); + return T(data_); + } + /*! \return Whether the container is not nullptr.*/ + explicit operator bool() const { + return *this != nullptr; + } + // operator overloadings + bool operator==(std::nullptr_t) const { + return data_ == nullptr; + } + bool operator!=(std::nullptr_t) const { + return data_ != nullptr; + } + auto operator==(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() == other.value()); + if (same_as(other)) return RetType(true); + if (*this != nullptr && other != nullptr) { + return value() == other.value(); + } else { + // one of them is nullptr. + return RetType(false); + } + } + auto operator!=(const Optional& other) const { + return !(*this == other); + } + auto operator==(const T& other) const { + using RetType = decltype(value() == other); + if (same_as(other)) return RetType(true); + if (*this != nullptr) return value() == other; + return RetType(false); + } + auto operator!=(const T& other) const { + return !(*this == other); + } + template + auto operator==(const U& other) const { + using RetType = decltype(value() == other); + if (*this == nullptr) return RetType(false); + return value() == other; + } + template + auto operator!=(const U& other) const { + return !(*this == other); + } + static constexpr bool _type_is_nullable = true; +}; + +template +struct PackedFuncValueConverter> { + static Optional From(const TVMArgValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } + static Optional From(const TVMRetValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } +}; + } // namespace runtime + +// expose the functions to the root namespace. +using runtime::String; +using runtime::Optional; } // namespace tvm namespace std { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index acbb9398b74c7..edca925baeb0d 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -546,7 +546,9 @@ class ObjectRef { bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } - /*! \return whether the expression is null */ + /*! + * \return whether the object is defined(not null). + */ bool defined() const { return data_ != nullptr; } @@ -582,6 +584,8 @@ class ObjectRef { /*! \brief type indicate the container type. */ using ContainerType = Object; + // Default type properties for the reference class. + static constexpr bool _type_is_nullable = true; protected: /*! \brief Internal pointer that backs the reference. */ @@ -720,6 +724,17 @@ struct ObjectEqual { TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TypeName::_GetOrAllocRuntimeTypeIndex() + +/* + * \brief Define the default copy/move constructor and assign opeator + * \param TypeName The class typename. + */ +#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ + TypeName& operator=(TypeName&& other) = default; \ + /* * \brief Define object reference methods. * \param TypeName The object type name @@ -727,15 +742,34 @@ struct ObjectEqual { * \param ObjectName The type name of the object. */ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() {} \ + TypeName() = default; \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { \ return static_cast(data_.get()); \ } \ using ContainerType = ObjectName; +/* + * \brief Define object reference methods that is not nullable. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName( \ + ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ + : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { \ + return static_cast(data_.get()); \ + } \ + static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName; + /* * \brief Define object reference methods of whose content is mutable. * \param TypeName The object type name @@ -745,7 +779,8 @@ struct ObjectEqual { * This macro is only reserved for objects that stores runtime states. */ #define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() {} \ + TypeName() = default; \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ @@ -869,11 +904,14 @@ inline const ObjectType* ObjectRef::as() const { } } -template -inline RelayRefType GetRef(const ObjType* ptr) { - static_assert(std::is_base_of::value, +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, "Can only cast to the ref of same container type"); - return RelayRefType(ObjectPtr(const_cast(static_cast(ptr)))); + if (!RefType::_type_is_nullable) { + CHECK(ptr != nullptr); + } + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); } template @@ -885,9 +923,15 @@ inline ObjectPtr GetObjectPtr(ObjType* ptr) { template inline SubRef Downcast(BaseRef ref) { - CHECK(!ref.defined() || ref->template IsInstance()) - << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; + if (ref.defined()) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } else { + CHECK(SubRef::_type_is_nullable) + << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; + } return SubRef(std::move(ref.data_)); } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index c5f0df57b10c5..3d5a7e865303a 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -352,7 +352,7 @@ template struct ObjectTypeChecker { static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; - if (ptr == nullptr) return true; + if (ptr == nullptr) return T::_type_is_nullable; return ptr->IsInstance(); } static std::string TypeName() { @@ -1400,7 +1400,11 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; - if (type_code_ == kTVMNullptr) return TObjectRef(ObjectPtr(nullptr)); + if (type_code_ == kTVMNullptr) { + CHECK(TObjectRef::_type_is_nullable) + << "Expect a not null value of " << ContainerType::_type_key; + return TObjectRef(ObjectPtr(nullptr)); + } // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { // Casting to a sub-class of NDArray diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f939f0a1e7d60..d7955a2ca6208 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -244,8 +244,9 @@ split_dev_host_funcs(IRModule mod_mixed, auto host_pass_list = { FilterBy([](const tir::PrimFunc& f) { - int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; - return value != static_cast(CallingConv::kDeviceKernelLaunch); + return f->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; }), BindTarget(target_host), tir::transform::LowerTVMBuiltin(), @@ -259,8 +260,9 @@ split_dev_host_funcs(IRModule mod_mixed, // device pipeline auto device_pass_list = { FilterBy([](const tir::PrimFunc& f) { - int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; - return value == static_cast(CallingConv::kDeviceKernelLaunch); + return f->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; }), BindTarget(target), tir::transform::LowerWarpMemory(), diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 9cb6b2efe28c0..4ed8fbc15abd7 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -620,14 +620,14 @@ class CompileEngineImpl : public CompileEngineNode { if (src_func->GetAttr(attr::kCompiler).defined()) { auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen; + std::string code_gen_name = code_gen.value(); if (ext_mods.find(code_gen_name) == ext_mods.end()) { ext_mods[code_gen_name] = IRModule({}, {}); } auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVar(std::string(symbol_name)); + auto gv = GlobalVar(std::string(symbol_name.value())); ext_mods[code_gen_name]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } @@ -698,7 +698,7 @@ class CompileEngineImpl : public CompileEngineNode { key->source_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = std::string(name_node); + cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); value->cached_func = CachedFunc(cache_node); return value; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 1b953f3c44671..7dfa4bac06ed8 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -72,7 +72,7 @@ class CSourceModuleCodegenBase { const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; - return std::string(name_node); + return std::string(name_node.value()); } }; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e2b0fffec8bde..8af6247fc810f 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -446,7 +446,7 @@ class VMFunctionCompiler : ExprFunctor { const Expr& outputs) { std::vector argument_registers; - CHECK_NE(func->GetAttr(attr::kPrimitive, 0)->value, 0) + CHECK(func->GetAttr(attr::kPrimitive, 0) != 0) << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 59c549cabfee8..bfbefd57a3105 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -45,7 +45,7 @@ inline std::string GenerateName(const Function& func) { } bool IsClosure(const Function& func) { - return func->GetAttr(attr::kClosure, 0)->value != 0; + return func->GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index a06eb5a4d3472..06dd2b16661f1 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -145,8 +145,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, } bool FunctionPassNode::SkipFunction(const Function& func) const { - return func->GetAttr(attr::kSkipOptimization, 0)->value != 0 || - (func->GetAttr(attr::kCompiler).defined()); + return (func->GetAttr(attr::kCompiler).defined()) || + func->GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 44d7b54e96374..2499982e321a1 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -158,9 +158,9 @@ class AnnotateTargetWrapper : public ExprMutator { // if it is in the target list. Function func = Downcast(cn->op); CHECK(func.defined()); - auto comp_name = func->GetAttr(attr::kComposite); - if (comp_name.defined()) { - std::string comp_name_str = comp_name; + + if (auto comp_name = func->GetAttr(attr::kComposite)) { + std::string comp_name_str = comp_name.value(); size_t i = comp_name_str.find('.'); if (i != std::string::npos) { std::string comp_target = comp_name_str.substr(0, i); diff --git a/src/target/build_common.h b/src/target/build_common.h index 5ba51da4ce672..93687c2578acc 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -51,14 +51,14 @@ ExtractFuncInfo(const IRModule& mod) { for (size_t i = 0; i < f->params.size(); ++i) { info.arg_types.push_back(f->params[i].dtype()); } - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - if (thread_axis.defined()) { + if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { + auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - fmap[static_cast(global_symbol)] = info; + fmap[static_cast(global_symbol.value())] = info; } return fmap; } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index a863056e82267..ad09730e6db83 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -130,7 +130,7 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( - std::make_pair(global_symbol.operator std::string(), + std::make_pair(global_symbol.value().operator std::string(), builder_->CreatePointerCast(function_, t_void_p_))); } AddDebugInformation(function_); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 7112691de1bcb..604533933b922 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -131,12 +131,12 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - CHECK(module_->getFunction(static_cast(global_symbol)) == nullptr) + CHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) << "Function " << global_symbol << " already exist in module"; function_ = llvm::Function::Create( ftype, llvm::Function::ExternalLinkage, - global_symbol.operator std::string(), module_.get()); + global_symbol.value().operator std::string(), module_.get()); function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 52dccbaf5eb61..d1a244d01ff40 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -216,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()); - entry_func = global_symbol; + entry_func = global_symbol.value(); } funcs.push_back(f); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 634fb9a57f27a..2d659e4487e31 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -138,8 +138,7 @@ runtime::Module BuildCUDA(IRModule mod) { << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index c6011cd4dc87f..64674e3360dd2 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -45,8 +45,7 @@ runtime::Module BuildAOCL(IRModule mod, << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a0e18a6120554..444dc996b10fa 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -84,7 +84,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol) << "("; + this->stream << " " << static_cast(global_symbol.value()) << "("; for (size_t i = 0; i < f->params.size(); ++i) { tir::Var v = f->params[i]; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 715c0ae92ddca..ea49d33351a06 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -61,7 +61,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. - this->stream << "kernel void " << static_cast(global_symbol) << "("; + this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; @@ -91,7 +91,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { size_t nargs = f->params.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { - std::string arg_buf_type = static_cast(global_symbol) + "_args_t"; + std::string arg_buf_type = + static_cast(global_symbol.value()) + "_args_t"; stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; // declare the struct @@ -120,8 +121,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - CHECK(thread_axis.defined()); + auto thread_axis = f->GetAttr>( + tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); @@ -278,8 +279,7 @@ runtime::Module BuildMetal(IRModule mod) { << "CodeGenMetal: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 67761c17680a8..d5b89609e514c 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -249,8 +249,7 @@ runtime::Module BuildOpenCL(IRModule mod) { << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 13d87d282e6cb..946b483a1dd9c 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -160,7 +160,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; - shaders_[static_cast(global_symbol)] = runtime::OpenGLShader( + shaders_[static_cast(global_symbol.value())] = runtime::OpenGLShader( this->decl_stream.str() + this->stream.str(), std::move(arg_names), std::move(arg_kinds), this->thread_extent_var_); @@ -299,8 +299,7 @@ runtime::Module BuildOpenGL(IRModule mod) { << "CodeGenOpenGL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 7486164444c4e..71c36264afa46 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -138,8 +138,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } @@ -164,7 +163,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - kernel_info.push_back({global_symbol, code}); + kernel_info.push_back({global_symbol.value(), code}); } std::string xclbin; diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 58721414a6651..161c1ca3bab10 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -87,14 +87,13 @@ runtime::Module BuildSPIRV(IRModule mod) { << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol; + std::string f_name = global_symbol.value(); f = PointerValueTypeRewrite(std::move(f)); VulkanShader shader; shader.data = cg.BuildFunction(f); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index db2a2f359aa48..bfe21b024426c 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -82,7 +82,8 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - builder_->CommitKernelFunction(func_ptr, static_cast(global_symbol)); + builder_->CommitKernelFunction( + func_ptr, static_cast(global_symbol.value())); return builder_->Finalize(); } diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index da75a70e91232..661fdabd3c32e 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -539,7 +539,7 @@ runtime::Module BuildStackVM(const IRModule& mod) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol; + std::string f_name = global_symbol.value(); StackVM vm = codegen::CodeGenStackVM().Compile(f); CHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list"; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index d6a521f984870..9ff4f3d5b7384 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -195,7 +195,7 @@ void VerifyMemory(const IRModule& mod) { auto target = func->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - MemoryAccessVerifier v(func, target->device_type); + MemoryAccessVerifier v(func, target.value()->device_type); v.Run(); if (v.Failed()) { LOG(FATAL) diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc index 952d6635f582a..a6db9f9c6da84 100644 --- a/src/tir/transforms/bind_device_type.cc +++ b/src/tir/transforms/bind_device_type.cc @@ -99,7 +99,7 @@ Pass BindDeviceType() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "BindDeviceType: Require the target attribute"; - n->body = DeviceTypeBinder(target->device_type)(std::move(n->body)); + n->body = DeviceTypeBinder(target.value()->device_type)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {}); diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 6026f8c67567f..6cf9e3adce967 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -141,7 +141,7 @@ Pass LowerCustomDatatypes() { CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; - n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body)); + n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 41a94937d4ce5..6ae638f334742 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -293,7 +293,7 @@ Pass LowerIntrin() { << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; n->body = - IntrinInjecter(&analyzer, target->target_name)(std::move(n->body)); + IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index c4df2dcdb868a..655a0074c7fde 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -348,7 +348,7 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; - n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body); + n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 1921db53cb060..612a8f4d9eef4 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -393,7 +393,7 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(std::move(n->body)); + n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index b1dd235bce03f..dd4bd66426767 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -48,9 +48,9 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) + CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; - std::string name_hint = global_symbol; + std::string name_hint = global_symbol.value(); auto* func_ptr = func.CopyOnWrite(); const Stmt nop = EvaluateNode::make(0); @@ -240,8 +240,9 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr(tvm::attr::kCallingConv, 0)->value - == static_cast(CallingConv::kDefault)) { + if (func->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); } diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index f3663532e56ba..fdcfc4d4702e6 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -82,7 +82,10 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) tmap[kv.first] = kv.second; } - auto thread_axis = f->GetAttr >(tir::attr::kDeviceThreadAxis); + auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); + CHECK(opt_thread_axis != nullptr) + << "Require attribute " << tir::attr::kDeviceThreadAxis; + auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); // replace the thread axis diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 792a06157c09c..927536b5938e7 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -277,7 +277,9 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; HostDeviceSplitter splitter( - device_mod, target, static_cast(global_symbol)); + device_mod, + target.value(), + static_cast(global_symbol.value())); auto* n = func.CopyOnWrite(); n->body = splitter(std::move(n->body)); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 063247db09b66..7be9e98295db6 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -401,6 +402,72 @@ TEST(String, Cast) { String s2 = Downcast(r); } + +TEST(Optional, Composition) { + Optional opt0(nullptr); + Optional opt1 = String("xyz"); + Optional opt2 = String("xyz1"); + // operator bool + CHECK(!opt0); + CHECK(opt1); + // comparison op + CHECK(opt0 != "xyz"); + CHECK(opt1 == "xyz"); + CHECK(opt1 != nullptr); + CHECK(opt0 == nullptr); + CHECK(opt0 != opt1); + CHECK(opt1 == Optional(String("xyz"))); + CHECK(opt0 == Optional(nullptr)); + opt0 = opt1; + CHECK(opt0 == opt1); + CHECK(opt0.value().same_as(opt1.value())); + opt0 = std::move(opt2); + CHECK(opt0 != opt2); +} + +TEST(Optional, IntCmp) { + Integer val(CallingConv::kDefault); + Optional opt = Integer(0); + CHECK(0 == static_cast(CallingConv::kDefault)); + CHECK(val == CallingConv::kDefault); + CHECK(opt == CallingConv::kDefault); + + // check we can handle implicit 0 to nullptr conversion. + Optional opt1(nullptr); + CHECK(opt1 != 0); + CHECK(opt1 != false); + CHECK(!(opt1 == 0)); +} + +TEST(Optional, PackedCall) { + auto tf = [](Optional s, bool isnull) { + if (isnull) { + CHECK(s == nullptr); + } else { + CHECK(s != nullptr); + } + return s; + }; + auto func = TypedPackedFunc(Optional, bool)>(tf); + func(String("xyz"), false); + func(Optional(nullptr), true); + + auto pf = [](TVMArgs args, TVMRetValue* rv) { + Optional s = args[0]; + bool isnull = args[1]; + if (isnull) { + CHECK(s == nullptr); + } else { + CHECK(s != nullptr); + } + *rv = s; + }; + auto packedfunc = PackedFunc(pf); + CHECK(packedfunc("xyz", false).operator String() == "xyz"); + CHECK(packedfunc("xyz", false).operator Optional() == "xyz"); + CHECK(packedfunc(nullptr, true).operator Optional() == nullptr); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe";