From 4942219c638d8613ba6a0132099b7778582811f2 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sat, 11 Apr 2020 21:13:33 -0700 Subject: [PATCH 1/4] [RUNTIME] Allow non-nullable ObjectRef, introduce Optional. We use ObjectRef and their sub-classes extensively throughout our codebase. Each of ObjectRef's sub-classes are nullable, which means they can hold nullptr as their values. While in some places we need nullptr as an alternative value. The implicit support for nullptr in all ObjectRef creates additional burdens for the developer to explicitly check defined in many places of the codebase. Moreover, it is unclear from the API's intentional point of view whether we want a nullable object or not-null version(many cases we want the later). Borrowing existing wisdoms from languages like Rust. We propose to introduce non-nullable ObjectRef, and Optional container that represents a nullable variant. To keep backward compatiblity, we will start by allowing most ObjectRef to be nullable. However, we should start to use Optional as the type in places where we know nullable is a requirement. Gradually, we will move most of the ObjectRef to be non-nullable and use Optional in the nullable cases. Such explicitness in typing can help reduce the potential problems in our codebase overall. Changes in this PR: - Introduce _type_is_nullable attribute to ObjectRef - Introduce Optional - Change String to be non-nullable. - Change the API of function->GetAttr to return Optional --- include/tvm/ir/attrs.h | 2 + include/tvm/ir/expr.h | 67 ++++++++- include/tvm/ir/function.h | 19 ++- include/tvm/node/node.h | 1 - include/tvm/runtime/container.h | 131 +++++++++++++++++- include/tvm/runtime/object.h | 64 +++++++-- include/tvm/runtime/packed_func.h | 8 +- src/driver/driver_api.cc | 10 +- src/relay/backend/compile_engine.cc | 6 +- .../backend/contrib/codegen_c/codegen_c.h | 2 +- src/relay/backend/vm/compiler.cc | 2 +- src/relay/backend/vm/lambda_lift.cc | 2 +- src/relay/ir/transform.cc | 4 +- src/relay/transforms/annotate_target.cc | 6 +- src/target/build_common.h | 6 +- src/target/llvm/codegen_cpu.cc | 2 +- src/target/llvm/codegen_llvm.cc | 4 +- src/target/llvm/llvm_module.cc | 2 +- src/target/opt/build_cuda_on.cc | 3 +- src/target/source/codegen_aocl.cc | 3 +- src/target/source/codegen_c.cc | 2 +- src/target/source/codegen_metal.cc | 12 +- src/target/source/codegen_opencl.cc | 3 +- src/target/source/codegen_opengl.cc | 5 +- src/target/source/codegen_vhls.cc | 5 +- src/target/spirv/build_vulkan.cc | 5 +- src/target/spirv/codegen_spirv.cc | 3 +- src/target/stackvm/codegen_stackvm.cc | 2 +- src/tir/analysis/verify_memory.cc | 2 +- src/tir/transforms/bind_device_type.cc | 2 +- src/tir/transforms/lower_custom_datatypes.cc | 2 +- src/tir/transforms/lower_intrin.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 2 +- src/tir/transforms/lower_warp_memory.cc | 2 +- src/tir/transforms/make_packed_api.cc | 9 +- src/tir/transforms/remap_thread_axis.cc | 5 +- src/tir/transforms/split_host_device.cc | 4 +- tests/cpp/container_test.cc | 67 +++++++++ 38 files changed, 393 insertions(+), 85 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 0fc832e0fb7a..7022b948781e 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -85,6 +85,8 @@ namespace tvm { */ template inline TObjectRef NullValue() { + static_assert(TObjectRef::_type_is_nullable, + "Can only value for nullable types"); return TObjectRef(ObjectPtr(nullptr)); } diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index 4e0a301156a3..9d77136e14ff 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -317,6 +317,47 @@ class FloatImm : public PrimExpr { TVM_DEFINE_OBJECT_REF_METHODS(FloatImm, PrimExpr, FloatImmNode); }; +/*! + * \brief Boolean constant. + * + * This reference type is useful to add additional compile-time + * type checks and helper functions for Integer equal comparisons. + */ +class Bool : public IntImm { + public: + explicit Bool(bool value) + : IntImm(DataType::Bool(), value) { + } + Bool operator!() const { + return Bool((*this)->value == 0); + } + operator bool() const { + return (*this)->value != 0; + } + + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Bool, IntImm, IntImmNode); +}; + +// Overload operators to make sure we have the most fine grained types. +inline Bool operator||(const Bool& a, bool b) { + return Bool(a.operator bool() || b); +} +inline Bool operator||(bool a, const Bool& b) { + return Bool(a || b.operator bool()); +} +inline Bool operator||(const Bool& a, const Bool& b) { + return Bool(a.operator bool() || b.operator bool()); +} +inline Bool operator&&(const Bool& a, bool b) { + return Bool(a.operator bool() && b); +} +inline Bool operator&&(bool a, const Bool& b) { + return Bool(a && b.operator bool()); +} +inline Bool operator&&(const Bool& a, const Bool& b) { + return Bool(a.operator bool() && b.operator bool()); +} + /*! * \brief Container of constant int that adds more constructors. * @@ -346,10 +387,10 @@ class Integer : public IntImm { * \tparam Enum The enum type. * \param value The enum value. */ - template::value>::type> - explicit Integer(ENum value) : Integer(static_cast(value)) { - static_assert(std::is_same::type>::value, + template::value>::type> + explicit Integer(Enum value) : Integer(static_cast(value)) { + static_assert(std::is_same::type>::value, "declare enum to be enum int to use visitor"); } /*! @@ -368,6 +409,24 @@ class Integer : public IntImm { << " Trying to reference a null Integer"; return (*this)->value; } + // comparators + Bool operator==(int other) const { + if (data_ == nullptr) return Bool(false); + return Bool((*this)->value == other); + } + Bool operator!=(int other) const { + return !(*this == other); + } + template::value>::type> + Bool operator==(Enum other) const { + return *this == static_cast(other); + } + template::value>::type> + Bool operator!=(Enum other) const { + return *this != static_cast(other); + } }; /*! \brief range over one dimension */ diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h index dc7a2b218568..d55656f34b00 100644 --- a/include/tvm/ir/function.h +++ b/include/tvm/ir/function.h @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -90,25 +91,31 @@ class BaseFuncNode : public RelayExprNode { * \code * * void GetAttrExample(const BaseFunc& f) { - * Integer value = f->GetAttr("AttrKey", 0); + * auto value = f->GetAttr("AttrKey", 0); * } * * \endcode */ template - TObjectRef GetAttr(const std::string& attr_key, - TObjectRef default_value = NullValue()) const { + Optional GetAttr( + const std::string& attr_key, + Optional default_value = Optional(nullptr)) const { static_assert(std::is_base_of::value, "Can only call GetAttr with ObjectRef types."); if (!attrs.defined()) return default_value; auto it = attrs->dict.find(attr_key); if (it != attrs->dict.end()) { - return Downcast((*it).second); + return Downcast>((*it).second); } else { return default_value; } } - + // variant that uses TObjectRef to enable implicit conversion to default value. + template + Optional GetAttr( + const std::string& attr_key, TObjectRef default_value) const { + return GetAttr(attr_key, Optional(default_value)); + } /*! * \brief Check whether the function has an non-zero integer attr. * @@ -129,7 +136,7 @@ class BaseFuncNode : public RelayExprNode { * \endcode */ bool HasNonzeroAttr(const std::string& attr_key) const { - return GetAttr(attr_key, 0)->value != 0; + return GetAttr(attr_key, 0) != 0; } static constexpr const char* _type_key = "BaseFunc"; diff --git a/include/tvm/node/node.h b/include/tvm/node/node.h index b39e3b403421..471a0de361b7 100644 --- a/include/tvm/node/node.h +++ b/include/tvm/node/node.h @@ -63,7 +63,6 @@ using runtime::make_object; using runtime::PackedFunc; using runtime::TVMArgs; using runtime::TVMRetValue; -using runtime::String; } // namespace tvm #endif // TVM_NODE_NODE_H_ diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 8963f0921276..3424be890b85 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -353,6 +353,10 @@ class StringObj : public Object { */ class String : public ObjectRef { public: + /*! + * \brief Construct an empty string. + */ + String() : String(std::string()) {} /*! * \brief Construct a new String object * @@ -467,9 +471,6 @@ class String : public ObjectRef { */ size_t size() const { const auto* ptr = get(); - if (ptr == nullptr) { - return 0; - } return ptr->size; } @@ -524,7 +525,7 @@ class String : public ObjectRef { /*! \return the internal StringObj pointer */ const StringObj* get() const { return operator->(); } - TVM_DEFINE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(String, ObjectRef, StringObj); private: /*! @@ -610,7 +611,129 @@ struct PackedFuncValueConverter<::tvm::runtime::String> { } }; +/*! + * \brief Optional container that to represent to a Nullable variant of T. + * \tparam T The original ObjectRef. + * + * \code + * + * Optional opt0 = nullptr; + * Optional opt1 = String("xyz"); + * CHECK(opt0 == nullptr); + * CHECK(opt1 == "xyz"); + * + * \endcode + */ +template +class Optional : public ObjectRef { + public: + using ContainerType = typename T::ContainerType; + static_assert(std::is_base_of::value, + "Optional is only defined for ObjectRef."); + // default constructors. + Optional() = default; + Optional(const Optional&) = default; + Optional(Optional&&) = default; + Optional& operator=(const Optional&) = default; + Optional& operator=(Optional&&) = default; + /*! + * \brief Construct from an ObjectPtr + * whose type already matches the ContainerType. + * \param ptr + */ + explicit Optional(ObjectPtr ptr) : ObjectRef(ptr) {} + // nullptr handling. + // disallow implicit conversion as 0 can be implicitly converted to nullptr_t + explicit Optional(std::nullptr_t) {} + Optional& operator=(std::nullptr_t) { + data_ = nullptr; + return *this; + } + // normal value handling. + Optional(T other) // NOLINT(*) + : ObjectRef(std::move(other)) { + } + Optional& operator=(T other) { + ObjectRef::operator=(std::move(other)); + return *this; + } + // delete the int constructor + // since Optional(0) is ambiguious + // 0 can be implicitly casted to nullptr_t + explicit Optional(int val) = delete; + Optional& operator=(int val) = delete; + /*! + * \return A not-null container value in the optional. + * \note This function performs not-null checking. + */ + T value() const { + CHECK(data_ != nullptr); + return T(data_); + } + /*! \return Whether the container is not nullptr.*/ + explicit operator bool() const { + return *this != nullptr; + } + // operator overloadings + bool operator==(std::nullptr_t) const { + return data_ == nullptr; + } + bool operator!=(std::nullptr_t) const { + return data_ != nullptr; + } + auto operator==(const Optional& other) const { + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() == other.value()); + if (same_as(other)) return RetType(true); + if (*this != nullptr && other != nullptr) { + return value() == other.value(); + } else { + // one of them is nullptr. + return RetType(false); + } + } + auto operator!=(const Optional& other) const { + return !(*this == other); + } + auto operator==(const T& other) const { + using RetType = decltype(value() == other); + if (same_as(other)) return RetType(true); + if (*this != nullptr) return value() == other; + return RetType(false); + } + auto operator!=(const T& other) const { + return !(*this == other); + } + template + auto operator==(const U& other) const { + using RetType = decltype(value() == other); + if (*this == nullptr) return RetType(false); + return value() == other; + } + template + auto operator!=(const U& other) const { + return !(*this == other); + } + static constexpr bool _type_is_nullable = true; +}; + +template +struct PackedFuncValueConverter> { + static Optional From(const TVMArgValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } + static Optional From(const TVMRetValue& val) { + if (val.type_code() == kTVMNullptr) return Optional(nullptr); + return PackedFuncValueConverter::From(val); + } +}; + } // namespace runtime + +// expose the functions to the root namespace. +using runtime::String; +using runtime::Optional; } // namespace tvm namespace std { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index acbb9398b74c..edca925baeb0 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -546,7 +546,9 @@ class ObjectRef { bool operator<(const ObjectRef& other) const { return data_.get() < other.data_.get(); } - /*! \return whether the expression is null */ + /*! + * \return whether the object is defined(not null). + */ bool defined() const { return data_ != nullptr; } @@ -582,6 +584,8 @@ class ObjectRef { /*! \brief type indicate the container type. */ using ContainerType = Object; + // Default type properties for the reference class. + static constexpr bool _type_is_nullable = true; protected: /*! \brief Internal pointer that backs the reference. */ @@ -720,6 +724,17 @@ struct ObjectEqual { TVM_STR_CONCAT(TVM_OBJECT_REG_VAR_DEF, __COUNTER__) = \ TypeName::_GetOrAllocRuntimeTypeIndex() + +/* + * \brief Define the default copy/move constructor and assign opeator + * \param TypeName The class typename. + */ +#define TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName) \ + TypeName(const TypeName& other) = default; \ + TypeName(TypeName&& other) = default; \ + TypeName& operator=(const TypeName& other) = default; \ + TypeName& operator=(TypeName&& other) = default; \ + /* * \brief Define object reference methods. * \param TypeName The object type name @@ -727,15 +742,34 @@ struct ObjectEqual { * \param ObjectName The type name of the object. */ #define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() {} \ + TypeName() = default; \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ const ObjectName* operator->() const { \ return static_cast(data_.get()); \ } \ using ContainerType = ObjectName; +/* + * \brief Define object reference methods that is not nullable. + * + * \param TypeName The object type name + * \param ParentType The parent type of the objectref + * \param ObjectName The type name of the object. + */ +#define TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + explicit TypeName( \ + ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ + : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { \ + return static_cast(data_.get()); \ + } \ + static constexpr bool _type_is_nullable = false; \ + using ContainerType = ObjectName; + /* * \brief Define object reference methods of whose content is mutable. * \param TypeName The object type name @@ -745,7 +779,8 @@ struct ObjectEqual { * This macro is only reserved for objects that stores runtime states. */ #define TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() {} \ + TypeName() = default; \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ explicit TypeName( \ ::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) \ : ParentType(n) {} \ @@ -869,11 +904,14 @@ inline const ObjectType* ObjectRef::as() const { } } -template -inline RelayRefType GetRef(const ObjType* ptr) { - static_assert(std::is_base_of::value, +template +inline RefType GetRef(const ObjType* ptr) { + static_assert(std::is_base_of::value, "Can only cast to the ref of same container type"); - return RelayRefType(ObjectPtr(const_cast(static_cast(ptr)))); + if (!RefType::_type_is_nullable) { + CHECK(ptr != nullptr); + } + return RefType(ObjectPtr(const_cast(static_cast(ptr)))); } template @@ -885,9 +923,15 @@ inline ObjectPtr GetObjectPtr(ObjType* ptr) { template inline SubRef Downcast(BaseRef ref) { - CHECK(!ref.defined() || ref->template IsInstance()) - << "Downcast from " << ref->GetTypeKey() << " to " - << SubRef::ContainerType::_type_key << " failed."; + if (ref.defined()) { + CHECK(ref->template IsInstance()) + << "Downcast from " << ref->GetTypeKey() << " to " + << SubRef::ContainerType::_type_key << " failed."; + } else { + CHECK(SubRef::_type_is_nullable) + << "Downcast from nullptr to not nullable reference of " + << SubRef::ContainerType::_type_key; + } return SubRef(std::move(ref.data_)); } diff --git a/include/tvm/runtime/packed_func.h b/include/tvm/runtime/packed_func.h index c5f0df57b10c..3d5a7e865303 100644 --- a/include/tvm/runtime/packed_func.h +++ b/include/tvm/runtime/packed_func.h @@ -352,7 +352,7 @@ template struct ObjectTypeChecker { static bool Check(const Object* ptr) { using ContainerType = typename T::ContainerType; - if (ptr == nullptr) return true; + if (ptr == nullptr) return T::_type_is_nullable; return ptr->IsInstance(); } static std::string TypeName() { @@ -1400,7 +1400,11 @@ inline TObjectRef TVMPODValue_::AsObjectRef() const { std::is_base_of::value, "Conversion only works for ObjectRef"); using ContainerType = typename TObjectRef::ContainerType; - if (type_code_ == kTVMNullptr) return TObjectRef(ObjectPtr(nullptr)); + if (type_code_ == kTVMNullptr) { + CHECK(TObjectRef::_type_is_nullable) + << "Expect a not null value of " << ContainerType::_type_key; + return TObjectRef(ObjectPtr(nullptr)); + } // NOTE: the following code can be optimized by constant folding. if (std::is_base_of::value) { // Casting to a sub-class of NDArray diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index f939f0a1e7d6..d7955a2ca620 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -244,8 +244,9 @@ split_dev_host_funcs(IRModule mod_mixed, auto host_pass_list = { FilterBy([](const tir::PrimFunc& f) { - int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; - return value != static_cast(CallingConv::kDeviceKernelLaunch); + return f->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) != CallingConv::kDeviceKernelLaunch; }), BindTarget(target_host), tir::transform::LowerTVMBuiltin(), @@ -259,8 +260,9 @@ split_dev_host_funcs(IRModule mod_mixed, // device pipeline auto device_pass_list = { FilterBy([](const tir::PrimFunc& f) { - int64_t value = f->GetAttr(tvm::attr::kCallingConv, 0)->value; - return value == static_cast(CallingConv::kDeviceKernelLaunch); + return f->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDeviceKernelLaunch; }), BindTarget(target), tir::transform::LowerWarpMemory(), diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 9cb6b2efe28c..4ed8fbc15abd 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -620,14 +620,14 @@ class CompileEngineImpl : public CompileEngineNode { if (src_func->GetAttr(attr::kCompiler).defined()) { auto code_gen = src_func->GetAttr(attr::kCompiler); CHECK(code_gen.defined()) << "No external codegen is set"; - std::string code_gen_name = code_gen; + std::string code_gen_name = code_gen.value(); if (ext_mods.find(code_gen_name) == ext_mods.end()) { ext_mods[code_gen_name] = IRModule({}, {}); } auto symbol_name = src_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(symbol_name.defined()) << "No external symbol is set for:\n" << AsText(src_func, false); - auto gv = GlobalVar(std::string(symbol_name)); + auto gv = GlobalVar(std::string(symbol_name.value())); ext_mods[code_gen_name]->Add(gv, src_func); cached_ext_funcs.push_back(it.first); } @@ -698,7 +698,7 @@ class CompileEngineImpl : public CompileEngineNode { key->source_func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "External function has not been attached a name yet."; - cache_node->func_name = std::string(name_node); + cache_node->func_name = std::string(name_node.value()); cache_node->target = tvm::target::ext_dev(); value->cached_func = CachedFunc(cache_node); return value; diff --git a/src/relay/backend/contrib/codegen_c/codegen_c.h b/src/relay/backend/contrib/codegen_c/codegen_c.h index 1b953f3c4467..7dfa4bac06ed 100644 --- a/src/relay/backend/contrib/codegen_c/codegen_c.h +++ b/src/relay/backend/contrib/codegen_c/codegen_c.h @@ -72,7 +72,7 @@ class CSourceModuleCodegenBase { const auto name_node = func->GetAttr(tvm::attr::kGlobalSymbol); CHECK(name_node.defined()) << "Fail to retrieve external symbol."; - return std::string(name_node); + return std::string(name_node.value()); } }; diff --git a/src/relay/backend/vm/compiler.cc b/src/relay/backend/vm/compiler.cc index e2b0fffec8bd..8af6247fc810 100644 --- a/src/relay/backend/vm/compiler.cc +++ b/src/relay/backend/vm/compiler.cc @@ -446,7 +446,7 @@ class VMFunctionCompiler : ExprFunctor { const Expr& outputs) { std::vector argument_registers; - CHECK_NE(func->GetAttr(attr::kPrimitive, 0)->value, 0) + CHECK(func->GetAttr(attr::kPrimitive, 0) != 0) << "internal error: invoke_tvm_op requires the first argument to be a relay::Function"; auto input_tuple = inputs.as(); diff --git a/src/relay/backend/vm/lambda_lift.cc b/src/relay/backend/vm/lambda_lift.cc index 59c549cabfee..bfbefd57a310 100644 --- a/src/relay/backend/vm/lambda_lift.cc +++ b/src/relay/backend/vm/lambda_lift.cc @@ -45,7 +45,7 @@ inline std::string GenerateName(const Function& func) { } bool IsClosure(const Function& func) { - return func->GetAttr(attr::kClosure, 0)->value != 0; + return func->GetAttr(attr::kClosure, 0) != 0; } Function MarkClosure(Function func) { diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc index a06eb5a4d347..06dd2b16661f 100644 --- a/src/relay/ir/transform.cc +++ b/src/relay/ir/transform.cc @@ -145,8 +145,8 @@ IRModule FunctionPassNode::operator()(IRModule mod, } bool FunctionPassNode::SkipFunction(const Function& func) const { - return func->GetAttr(attr::kSkipOptimization, 0)->value != 0 || - (func->GetAttr(attr::kCompiler).defined()); + return (func->GetAttr(attr::kCompiler).defined()) || + func->GetAttr(attr::kSkipOptimization, 0) != 0; } Pass CreateFunctionPass( diff --git a/src/relay/transforms/annotate_target.cc b/src/relay/transforms/annotate_target.cc index 44d7b54e9637..2499982e321a 100644 --- a/src/relay/transforms/annotate_target.cc +++ b/src/relay/transforms/annotate_target.cc @@ -158,9 +158,9 @@ class AnnotateTargetWrapper : public ExprMutator { // if it is in the target list. Function func = Downcast(cn->op); CHECK(func.defined()); - auto comp_name = func->GetAttr(attr::kComposite); - if (comp_name.defined()) { - std::string comp_name_str = comp_name; + + if (auto comp_name = func->GetAttr(attr::kComposite)) { + std::string comp_name_str = comp_name.value(); size_t i = comp_name_str.find('.'); if (i != std::string::npos) { std::string comp_target = comp_name_str.substr(0, i); diff --git a/src/target/build_common.h b/src/target/build_common.h index 5ba51da4ce67..93687c2578ac 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -51,14 +51,14 @@ ExtractFuncInfo(const IRModule& mod) { for (size_t i = 0; i < f->params.size(); ++i) { info.arg_types.push_back(f->params[i].dtype()); } - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - if (thread_axis.defined()) { + if (auto opt = f->GetAttr>(tir::attr::kDeviceThreadAxis)) { + auto thread_axis = opt.value(); for (size_t i = 0; i < thread_axis.size(); ++i) { info.thread_axis_tags.push_back(thread_axis[i]->thread_tag); } } auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); - fmap[static_cast(global_symbol)] = info; + fmap[static_cast(global_symbol.value())] = info; } return fmap; } diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index a863056e8226..ad09730e6db8 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -130,7 +130,7 @@ void CodeGenCPU::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; export_system_symbols_.emplace_back( - std::make_pair(global_symbol.operator std::string(), + std::make_pair(global_symbol.value().operator std::string(), builder_->CreatePointerCast(function_, t_void_p_))); } AddDebugInformation(function_); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 7112691de1bc..604533933b92 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -131,12 +131,12 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenLLVM: Expect PrimFunc to have the global_symbol attribute"; - CHECK(module_->getFunction(static_cast(global_symbol)) == nullptr) + CHECK(module_->getFunction(static_cast(global_symbol.value())) == nullptr) << "Function " << global_symbol << " already exist in module"; function_ = llvm::Function::Create( ftype, llvm::Function::ExternalLinkage, - global_symbol.operator std::string(), module_.get()); + global_symbol.value().operator std::string(), module_.get()); function_->setCallingConv(llvm::CallingConv::C); function_->setDLLStorageClass(llvm::GlobalValue::DLLStorageClassTypes::DLLExportStorageClass); diff --git a/src/target/llvm/llvm_module.cc b/src/target/llvm/llvm_module.cc index 52dccbaf5eb6..d1a244d01ff4 100644 --- a/src/target/llvm/llvm_module.cc +++ b/src/target/llvm/llvm_module.cc @@ -216,7 +216,7 @@ class LLVMModuleNode final : public runtime::ModuleNode { if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()); - entry_func = global_symbol; + entry_func = global_symbol.value(); } funcs.push_back(f); } diff --git a/src/target/opt/build_cuda_on.cc b/src/target/opt/build_cuda_on.cc index 634fb9a57f27..2d659e4487e3 100644 --- a/src/target/opt/build_cuda_on.cc +++ b/src/target/opt/build_cuda_on.cc @@ -138,8 +138,7 @@ runtime::Module BuildCUDA(IRModule mod) { << "CodeGenCUDA: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenCUDA: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_aocl.cc b/src/target/source/codegen_aocl.cc index c6011cd4dc87..64674e3360dd 100644 --- a/src/target/source/codegen_aocl.cc +++ b/src/target/source/codegen_aocl.cc @@ -45,8 +45,7 @@ runtime::Module BuildAOCL(IRModule mod, << "CodegenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodegenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index a0e18a612055..444dc996b10f 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -84,7 +84,7 @@ void CodeGenC::AddFunction(const PrimFunc& f) { bool no_alias = f->HasNonzeroAttr(tir::attr::kNoAlias); this->PrintFuncPrefix(); - this->stream << " " << static_cast(global_symbol) << "("; + this->stream << " " << static_cast(global_symbol.value()) << "("; for (size_t i = 0; i < f->params.size(); ++i) { tir::Var v = f->params[i]; diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 715c0ae92ddc..ea49d33351a0 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -61,7 +61,7 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; // Function header. - this->stream << "kernel void " << static_cast(global_symbol) << "("; + this->stream << "kernel void " << static_cast(global_symbol.value()) << "("; // Buffer arguments size_t num_buffer = 0; @@ -91,7 +91,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { size_t nargs = f->params.size() - num_buffer; std::string varg = GetUniqueName("arg"); if (nargs != 0) { - std::string arg_buf_type = static_cast(global_symbol) + "_args_t"; + std::string arg_buf_type = + static_cast(global_symbol.value()) + "_args_t"; stream << " constant " << arg_buf_type << "& " << varg << " [[ buffer(" << num_buffer << ") ]],\n"; // declare the struct @@ -120,8 +121,8 @@ void CodeGenMetal::AddFunction(const PrimFunc& f) { CHECK_EQ(GetUniqueName("threadIdx"), "threadIdx"); CHECK_EQ(GetUniqueName("blockIdx"), "blockIdx"); int work_dim = 0; - auto thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); - CHECK(thread_axis.defined()); + auto thread_axis = f->GetAttr>( + tir::attr::kDeviceThreadAxis).value(); for (IterVar iv : thread_axis) { runtime::ThreadScope scope = runtime::ThreadScope::make(iv->thread_tag); @@ -278,8 +279,7 @@ runtime::Module BuildMetal(IRModule mod) { << "CodeGenMetal: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenMetal: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_opencl.cc b/src/target/source/codegen_opencl.cc index 67761c17680a..d5b89609e514 100644 --- a/src/target/source/codegen_opencl.cc +++ b/src/target/source/codegen_opencl.cc @@ -249,8 +249,7 @@ runtime::Module BuildOpenCL(IRModule mod) { << "CodeGenOpenCL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenCL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_opengl.cc b/src/target/source/codegen_opengl.cc index 13d87d282e6c..946b483a1dd9 100644 --- a/src/target/source/codegen_opengl.cc +++ b/src/target/source/codegen_opengl.cc @@ -160,7 +160,7 @@ void CodeGenOpenGL::AddFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenOpenGL: Expect PrimFunc to have the global_symbol attribute"; - shaders_[static_cast(global_symbol)] = runtime::OpenGLShader( + shaders_[static_cast(global_symbol.value())] = runtime::OpenGLShader( this->decl_stream.str() + this->stream.str(), std::move(arg_names), std::move(arg_kinds), this->thread_extent_var_); @@ -299,8 +299,7 @@ runtime::Module BuildOpenGL(IRModule mod) { << "CodeGenOpenGL: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenOpenGL: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } diff --git a/src/target/source/codegen_vhls.cc b/src/target/source/codegen_vhls.cc index 7486164444c4..71c36264afa4 100644 --- a/src/target/source/codegen_vhls.cc +++ b/src/target/source/codegen_vhls.cc @@ -138,8 +138,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { << "CodeGenVHLS: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenVLHS: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; cg.AddFunction(f); } @@ -164,7 +163,7 @@ runtime::Module BuildSDAccel(IRModule mod, std::string target_str) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenC: Expect PrimFunc to have the global_symbol attribute"; - kernel_info.push_back({global_symbol, code}); + kernel_info.push_back({global_symbol.value(), code}); } std::string xclbin; diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index 58721414a665..161c1ca3bab1 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -87,14 +87,13 @@ runtime::Module BuildSPIRV(IRModule mod) { << "CodeGenSPIRV: Can only take PrimFunc"; auto f = Downcast(kv.second); auto calling_conv = f->GetAttr(tvm::attr::kCallingConv); - CHECK(calling_conv.defined() && - calling_conv->value == static_cast(CallingConv::kDeviceKernelLaunch)) + CHECK(calling_conv == CallingConv::kDeviceKernelLaunch) << "CodeGenSPIRV: expect calling_conv equals CallingConv::kDeviceKernelLaunch"; auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol; + std::string f_name = global_symbol.value(); f = PointerValueTypeRewrite(std::move(f)); VulkanShader shader; shader.data = cg.BuildFunction(f); diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index db2a2f359aa4..bfe21b024426 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -82,7 +82,8 @@ std::vector CodeGenSPIRV::BuildFunction(const PrimFunc& f) { CHECK(global_symbol.defined()) << "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute"; - builder_->CommitKernelFunction(func_ptr, static_cast(global_symbol)); + builder_->CommitKernelFunction( + func_ptr, static_cast(global_symbol.value())); return builder_->Finalize(); } diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index da75a70e9123..661fdabd3c32 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -539,7 +539,7 @@ runtime::Module BuildStackVM(const IRModule& mod) { auto global_symbol = f->GetAttr(tvm::attr::kGlobalSymbol); CHECK(global_symbol.defined()) << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute"; - std::string f_name = global_symbol; + std::string f_name = global_symbol.value(); StackVM vm = codegen::CodeGenStackVM().Compile(f); CHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list"; diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index d6a521f98487..9ff4f3d5b738 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -195,7 +195,7 @@ void VerifyMemory(const IRModule& mod) { auto target = func->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - MemoryAccessVerifier v(func, target->device_type); + MemoryAccessVerifier v(func, target.value()->device_type); v.Run(); if (v.Failed()) { LOG(FATAL) diff --git a/src/tir/transforms/bind_device_type.cc b/src/tir/transforms/bind_device_type.cc index 952d6635f582..a6db9f9c6da8 100644 --- a/src/tir/transforms/bind_device_type.cc +++ b/src/tir/transforms/bind_device_type.cc @@ -99,7 +99,7 @@ Pass BindDeviceType() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "BindDeviceType: Require the target attribute"; - n->body = DeviceTypeBinder(target->device_type)(std::move(n->body)); + n->body = DeviceTypeBinder(target.value()->device_type)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.BindDeviceType", {}); diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 6026f8c67567..6cf9e3adce96 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ b/src/tir/transforms/lower_custom_datatypes.cc @@ -141,7 +141,7 @@ Pass LowerCustomDatatypes() { CHECK(target.defined()) << "LowerCustomDatatypes: Require the target attribute"; - n->body = CustomDatatypesLowerer(target->target_name)(std::move(n->body)); + n->body = CustomDatatypesLowerer(target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerCustomDatatypes", {}); diff --git a/src/tir/transforms/lower_intrin.cc b/src/tir/transforms/lower_intrin.cc index 41a94937d4ce..6ae638f33474 100644 --- a/src/tir/transforms/lower_intrin.cc +++ b/src/tir/transforms/lower_intrin.cc @@ -293,7 +293,7 @@ Pass LowerIntrin() { << "LowerIntrin: Require the target attribute"; arith::Analyzer analyzer; n->body = - IntrinInjecter(&analyzer, target->target_name)(std::move(n->body)); + IntrinInjecter(&analyzer, target.value()->target_name)(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerIntrin", {}); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index c4df2dcdb868..655a0074c7fd 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -348,7 +348,7 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; - n->body = ThreadAllreduceBuilder(target->thread_warp_size)(n->body); + n->body = ThreadAllreduceBuilder(target.value()->thread_warp_size)(n->body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 1921db53cb06..612a8f4d9eef 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -393,7 +393,7 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - n->body = WarpMemoryRewriter(target->thread_warp_size).Rewrite(std::move(n->body)); + n->body = WarpMemoryRewriter(target.value()->thread_warp_size).Rewrite(std::move(n->body)); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index b1dd235bce03..dd4bd6642676 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -48,9 +48,9 @@ inline Stmt MakeAssertEQ(PrimExpr lhs, PrimExpr rhs, std::string msg) { PrimFunc MakePackedAPI(PrimFunc&& func, int num_unpacked_args) { auto global_symbol = func->GetAttr(tvm::attr::kGlobalSymbol); - CHECK(global_symbol.defined()) + CHECK(global_symbol) << "MakePackedAPI: Expect PrimFunc to have the global_symbol attribute"; - std::string name_hint = global_symbol; + std::string name_hint = global_symbol.value(); auto* func_ptr = func.CopyOnWrite(); const Stmt nop = EvaluateNode::make(0); @@ -240,8 +240,9 @@ Pass MakePackedAPI(int num_unpacked_args) { for (const auto& kv : mptr->functions) { if (auto* n = kv.second.as()) { PrimFunc func = GetRef(n); - if (func->GetAttr(tvm::attr::kCallingConv, 0)->value - == static_cast(CallingConv::kDefault)) { + if (func->GetAttr( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDefault) { auto updated_func = MakePackedAPI(std::move(func), num_unpacked_args); updates.push_back({kv.first, updated_func}); } diff --git a/src/tir/transforms/remap_thread_axis.cc b/src/tir/transforms/remap_thread_axis.cc index f3663532e56b..fdcfc4d4702e 100644 --- a/src/tir/transforms/remap_thread_axis.cc +++ b/src/tir/transforms/remap_thread_axis.cc @@ -82,7 +82,10 @@ PrimFunc RemapThreadAxis(PrimFunc&& f, Map thread_map) tmap[kv.first] = kv.second; } - auto thread_axis = f->GetAttr >(tir::attr::kDeviceThreadAxis); + auto opt_thread_axis = f->GetAttr>(tir::attr::kDeviceThreadAxis); + CHECK(opt_thread_axis != nullptr) + << "Require attribute " << tir::attr::kDeviceThreadAxis; + auto thread_axis = opt_thread_axis.value(); auto* n = f.CopyOnWrite(); // replace the thread axis diff --git a/src/tir/transforms/split_host_device.cc b/src/tir/transforms/split_host_device.cc index 792a06157c09..927536b5938e 100644 --- a/src/tir/transforms/split_host_device.cc +++ b/src/tir/transforms/split_host_device.cc @@ -277,7 +277,9 @@ PrimFunc SplitHostDevice(PrimFunc&& func, IRModule* device_mod) { << "SplitHostDevice: Expect PrimFunc to have the global_symbol attribute"; HostDeviceSplitter splitter( - device_mod, target, static_cast(global_symbol)); + device_mod, + target.value(), + static_cast(global_symbol.value())); auto* n = func.CopyOnWrite(); n->body = splitter(std::move(n->body)); diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 063247db09b6..7be9e98295db 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include #include @@ -401,6 +402,72 @@ TEST(String, Cast) { String s2 = Downcast(r); } + +TEST(Optional, Composition) { + Optional opt0(nullptr); + Optional opt1 = String("xyz"); + Optional opt2 = String("xyz1"); + // operator bool + CHECK(!opt0); + CHECK(opt1); + // comparison op + CHECK(opt0 != "xyz"); + CHECK(opt1 == "xyz"); + CHECK(opt1 != nullptr); + CHECK(opt0 == nullptr); + CHECK(opt0 != opt1); + CHECK(opt1 == Optional(String("xyz"))); + CHECK(opt0 == Optional(nullptr)); + opt0 = opt1; + CHECK(opt0 == opt1); + CHECK(opt0.value().same_as(opt1.value())); + opt0 = std::move(opt2); + CHECK(opt0 != opt2); +} + +TEST(Optional, IntCmp) { + Integer val(CallingConv::kDefault); + Optional opt = Integer(0); + CHECK(0 == static_cast(CallingConv::kDefault)); + CHECK(val == CallingConv::kDefault); + CHECK(opt == CallingConv::kDefault); + + // check we can handle implicit 0 to nullptr conversion. + Optional opt1(nullptr); + CHECK(opt1 != 0); + CHECK(opt1 != false); + CHECK(!(opt1 == 0)); +} + +TEST(Optional, PackedCall) { + auto tf = [](Optional s, bool isnull) { + if (isnull) { + CHECK(s == nullptr); + } else { + CHECK(s != nullptr); + } + return s; + }; + auto func = TypedPackedFunc(Optional, bool)>(tf); + func(String("xyz"), false); + func(Optional(nullptr), true); + + auto pf = [](TVMArgs args, TVMRetValue* rv) { + Optional s = args[0]; + bool isnull = args[1]; + if (isnull) { + CHECK(s == nullptr); + } else { + CHECK(s != nullptr); + } + *rv = s; + }; + auto packedfunc = PackedFunc(pf); + CHECK(packedfunc("xyz", false).operator String() == "xyz"); + CHECK(packedfunc("xyz", false).operator Optional() == "xyz"); + CHECK(packedfunc(nullptr, true).operator Optional() == nullptr); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; From ebba40da2592008835d1b80c0e085c6a1e6a6d46 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 12 Apr 2020 17:05:39 -0700 Subject: [PATCH 2/4] Address review comments --- include/tvm/ir/attrs.h | 2 +- include/tvm/runtime/container.h | 7 +++++++ tests/cpp/container_test.cc | 6 ++++-- 3 files changed, 12 insertions(+), 3 deletions(-) diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h index 7022b948781e..d12f1b85114c 100644 --- a/include/tvm/ir/attrs.h +++ b/include/tvm/ir/attrs.h @@ -86,7 +86,7 @@ namespace tvm { template inline TObjectRef NullValue() { static_assert(TObjectRef::_type_is_nullable, - "Can only value for nullable types"); + "Can only get NullValue for nullable types"); return TObjectRef(ObjectPtr(nullptr)); } diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index 3424be890b85..e02cd5360a6b 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -670,6 +670,13 @@ class Optional : public ObjectRef { CHECK(data_ != nullptr); return T(data_); } + /*! + * \return The contained value if the Optional is not null + * otherwise return the default_value. + */ + T value_or(T default_value) const { + return data_ != nullptr ? T(data_) : default_value; + } /*! \return Whether the container is not nullptr.*/ explicit operator bool() const { return *this != nullptr; diff --git a/tests/cpp/container_test.cc b/tests/cpp/container_test.cc index 7be9e98295db..c67df63e6e7e 100644 --- a/tests/cpp/container_test.cc +++ b/tests/cpp/container_test.cc @@ -415,6 +415,8 @@ TEST(Optional, Composition) { CHECK(opt1 == "xyz"); CHECK(opt1 != nullptr); CHECK(opt0 == nullptr); + CHECK(opt0.value_or("abc") == "abc"); + CHECK(opt1.value_or("abc") == "xyz"); CHECK(opt0 != opt1); CHECK(opt1 == Optional(String("xyz"))); CHECK(opt0 == Optional(nullptr)); @@ -449,8 +451,8 @@ TEST(Optional, PackedCall) { return s; }; auto func = TypedPackedFunc(Optional, bool)>(tf); - func(String("xyz"), false); - func(Optional(nullptr), true); + CHECK(func(String("xyz"), false) == "xyz"); + CHECK(func(Optional(nullptr), true) == nullptr); auto pf = [](TVMArgs args, TVMRetValue* rv) { Optional s = args[0]; From 11a48b4d2bc502446b21d983674ab57c567d4698 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 12 Apr 2020 20:22:11 -0700 Subject: [PATCH 3/4] Upgrade all compiler flags to c++14 --- apps/android_camera/app/src/main/jni/Application.mk | 2 +- apps/android_deploy/app/src/main/jni/Application.mk | 4 ++-- apps/android_rpc/app/src/main/jni/Application.mk | 2 +- apps/cpp_rpc/Makefile | 2 +- apps/dso_plugin_module/Makefile | 2 +- apps/extension/Makefile | 2 +- apps/howto_deploy/Makefile | 2 +- apps/howto_deploy/tvm_runtime_pack.cc | 2 +- apps/rocm_rpc/Makefile | 2 +- apps/tf_tvmdsoop/CMakeLists.txt | 2 +- golang/Makefile | 2 +- python/setup.py | 2 +- tests/python/relay/test_external_codegen.py | 2 +- tests/python/relay/test_external_runtime.py | 4 ++-- tests/python/relay/test_pass_annotate_target.py | 2 +- tests/python/relay/test_pass_partition_graph.py | 2 +- tests/python/unittest/test_runtime_module_export.py | 2 +- vta/python/vta/exec/rpc_server.py | 2 +- 18 files changed, 20 insertions(+), 20 deletions(-) diff --git a/apps/android_camera/app/src/main/jni/Application.mk b/apps/android_camera/app/src/main/jni/Application.mk index 95a5a9697bcc..63a79458ef94 100644 --- a/apps/android_camera/app/src/main/jni/Application.mk +++ b/apps/android_camera/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= all APP_STL := c++_shared -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/android_deploy/app/src/main/jni/Application.mk b/apps/android_deploy/app/src/main/jni/Application.mk index ee13eb8a1213..a50a40bf5cd1 100644 --- a/apps/android_deploy/app/src/main/jni/Application.mk +++ b/apps/android_deploy/app/src/main/jni/Application.mk @@ -27,7 +27,7 @@ include $(config) APP_STL := c++_static -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti -ifeq ($(USE_OPENCL), 1) +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti +ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/android_rpc/app/src/main/jni/Application.mk b/apps/android_rpc/app/src/main/jni/Application.mk index 56288bde9898..54abdf771e2a 100644 --- a/apps/android_rpc/app/src/main/jni/Application.mk +++ b/apps/android_rpc/app/src/main/jni/Application.mk @@ -31,7 +31,7 @@ include $(config) APP_ABI ?= armeabi-v7a arm64-v8a x86 x86_64 mips APP_STL := c++_shared -APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++11 -Oz -frtti +APP_CPPFLAGS += -DDMLC_LOG_STACK_TRACE=0 -DTVM4J_ANDROID=1 -std=c++14 -Oz -frtti ifeq ($(USE_OPENCL), 1) APP_CPPFLAGS += -DTVM_OPENCL_RUNTIME=1 endif diff --git a/apps/cpp_rpc/Makefile b/apps/cpp_rpc/Makefile index 9cd39b446acc..927331ad00ea 100644 --- a/apps/cpp_rpc/Makefile +++ b/apps/cpp_rpc/Makefile @@ -28,7 +28,7 @@ else LINK_PTHREAD= endif -PKG_CFLAGS = -std=c++11 -O2 -fPIC -Wall\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC -Wall\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/dso_plugin_module/Makefile b/apps/dso_plugin_module/Makefile index 2ee6189e2876..c2ce3306870a 100644 --- a/apps/dso_plugin_module/Makefile +++ b/apps/dso_plugin_module/Makefile @@ -16,7 +16,7 @@ # under the License. TVM_ROOT=$(shell cd ../..; pwd) -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/extension/Makefile b/apps/extension/Makefile index e178b661f403..91d914aba63b 100644 --- a/apps/extension/Makefile +++ b/apps/extension/Makefile @@ -17,7 +17,7 @@ # Minimum Makefile for the extension package TVM_ROOT=$(shell cd ../..; pwd) -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${TVM_ROOT}/3rdparty/dmlc-core/include\ -I${TVM_ROOT}/3rdparty/dlpack/include diff --git a/apps/howto_deploy/Makefile b/apps/howto_deploy/Makefile index a260e89bc042..4ee243c2ce60 100644 --- a/apps/howto_deploy/Makefile +++ b/apps/howto_deploy/Makefile @@ -19,7 +19,7 @@ TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\ diff --git a/apps/howto_deploy/tvm_runtime_pack.cc b/apps/howto_deploy/tvm_runtime_pack.cc index d166eaf756a5..81bab497bebb 100644 --- a/apps/howto_deploy/tvm_runtime_pack.cc +++ b/apps/howto_deploy/tvm_runtime_pack.cc @@ -24,7 +24,7 @@ * include in your project. * * - Copy this file into your project which depends on tvm runtime. - * - Compile with -std=c++11 + * - Compile with -std=c++14 * - Add the following include path * - /path/to/tvm/include/ * - /path/to/tvm/3rdparty/dmlc-core/include/ diff --git a/apps/rocm_rpc/Makefile b/apps/rocm_rpc/Makefile index 36eb41596be8..971ca4603314 100644 --- a/apps/rocm_rpc/Makefile +++ b/apps/rocm_rpc/Makefile @@ -21,7 +21,7 @@ ROCM_PATH=/opt/rocm TVM_ROOT=$(shell cd ../..; pwd) DMLC_CORE=${TVM_ROOT}/3rdparty/dmlc-core -PKG_CFLAGS = -std=c++11 -O2 -fPIC\ +PKG_CFLAGS = -std=c++14 -O2 -fPIC\ -I${TVM_ROOT}/include\ -I${DMLC_CORE}/include\ -I${TVM_ROOT}/3rdparty/dlpack/include\ diff --git a/apps/tf_tvmdsoop/CMakeLists.txt b/apps/tf_tvmdsoop/CMakeLists.txt index cb601ef6d30d..f4e83c528701 100644 --- a/apps/tf_tvmdsoop/CMakeLists.txt +++ b/apps/tf_tvmdsoop/CMakeLists.txt @@ -17,7 +17,7 @@ cmake_minimum_required(VERSION 3.2) project(tf_tvmdsoop C CXX) -set(TFTVM_COMPILE_FLAGS -std=c++11) +set(TFTVM_COMPILE_FLAGS -std=c++14) set(BUILD_TVMDSOOP_ONLY ON) set(CMAKE_CURRENT_SOURCE_DIR ${TVM_ROOT}) set(CMAKE_CURRENT_BINARY_DIR ${TVM_ROOT}/build) diff --git a/golang/Makefile b/golang/Makefile index c54fd0e0992c..6fd77996e119 100644 --- a/golang/Makefile +++ b/golang/Makefile @@ -25,7 +25,7 @@ NATIVE_SRC = tvm_runtime_pack.cc GOPATH=$(CURDIR)/gopath GOPATHDIR=${GOPATH}/src/${TARGET}/ CGO_CPPFLAGS="-I. -I${TVM_BASE}/ -I${TVM_BASE}/3rdparty/dmlc-core/include -I${TVM_BASE}/include -I${TVM_BASE}/3rdparty/dlpack/include/" -CGO_CXXFLAGS="-std=c++11" +CGO_CXXFLAGS="-std=c++14" CGO_CFLAGS="-I${TVM_BASE}" CGO_LDFLAGS="-ldl -lm" diff --git a/python/setup.py b/python/setup.py index 937d682e3c85..62f374923714 100644 --- a/python/setup.py +++ b/python/setup.py @@ -96,7 +96,7 @@ def config_cython(): "../3rdparty/dmlc-core/include", "../3rdparty/dlpack/include", ], - extra_compile_args=["-std=c++11"], + extra_compile_args=["-std=c++14"], library_dirs=library_dirs, libraries=libraries, language="c++")) diff --git a/tests/python/relay/test_external_codegen.py b/tests/python/relay/test_external_codegen.py index b4496bb044ba..3797910080a1 100644 --- a/tests/python/relay/test_external_codegen.py +++ b/tests/python/relay/test_external_codegen.py @@ -39,7 +39,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/relay/test_external_runtime.py b/tests/python/relay/test_external_runtime.py index 397c35db53d9..39209232f3d0 100644 --- a/tests/python/relay/test_external_runtime.py +++ b/tests/python/relay/test_external_runtime.py @@ -468,13 +468,13 @@ def run_extern(label, get_extern_src, **kwargs): def test_dso_extern(): - run_extern("lib", generate_csource_module, options=["-O2", "-std=c++11"]) + run_extern("lib", generate_csource_module, options=["-O2", "-std=c++14"]) def test_engine_extern(): run_extern("engine", generate_engine_module, - options=["-O2", "-std=c++11", "-I" + tmp_path.relpath("")]) + options=["-O2", "-std=c++14", "-I" + tmp_path.relpath("")]) def test_json_extern(): if not tvm.get_global_func("module.loadfile_examplejson", True): diff --git a/tests/python/relay/test_pass_annotate_target.py b/tests/python/relay/test_pass_annotate_target.py index 705a2614674a..01ba9b619205 100644 --- a/tests/python/relay/test_pass_annotate_target.py +++ b/tests/python/relay/test_pass_annotate_target.py @@ -42,7 +42,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/relay/test_pass_partition_graph.py b/tests/python/relay/test_pass_partition_graph.py index c7d9626931d0..1d0cc5b79a44 100644 --- a/tests/python/relay/test_pass_partition_graph.py +++ b/tests/python/relay/test_pass_partition_graph.py @@ -182,7 +182,7 @@ def update_lib(lib): contrib_path = os.path.join(source_dir, "src", "runtime", "contrib") kwargs = {} - kwargs["options"] = ["-O2", "-std=c++11", "-I" + contrib_path] + kwargs["options"] = ["-O2", "-std=c++14", "-I" + contrib_path] tmp_path = util.tempdir() lib_name = 'lib.so' lib_path = tmp_path.relpath(lib_name) diff --git a/tests/python/unittest/test_runtime_module_export.py b/tests/python/unittest/test_runtime_module_export.py index 35bafb4ba3c7..fce7d2f350dc 100644 --- a/tests/python/unittest/test_runtime_module_export.py +++ b/tests/python/unittest/test_runtime_module_export.py @@ -191,7 +191,7 @@ def verify_multi_c_mod_export(): path_lib = temp.relpath(file_name) resnet18_cpu_lib.import_module(f) resnet18_cpu_lib.import_module(engine_module) - kwargs = {"options": ["-O2", "-std=c++11", "-I" + header_file_dir_path.relpath("")]} + kwargs = {"options": ["-O2", "-std=c++14", "-I" + header_file_dir_path.relpath("")]} resnet18_cpu_lib.export_library(path_lib, fcompile=False, **kwargs) loaded_lib = tvm.runtime.load_module(path_lib) assert loaded_lib.type_key == "library" diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py index 65c1274214e6..9cfd50927041 100644 --- a/vta/python/vta/exec/rpc_server.py +++ b/vta/python/vta/exec/rpc_server.py @@ -107,7 +107,7 @@ def reconfig_runtime(cfg_json): if pkg.same_config(old_cfg): logging.info("Skip reconfig_runtime due to same config.") return - cflags = ["-O2", "-std=c++11"] + cflags = ["-O2", "-std=c++14"] cflags += pkg.cflags ldflags = pkg.ldflags lib_name = dll_path From 3db3ad87e0e9bdb8ec05cc907cc28f76eeb58f0a Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 12 Apr 2020 20:34:58 -0700 Subject: [PATCH 4/4] Update as per review comment --- include/tvm/runtime/container.h | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/include/tvm/runtime/container.h b/include/tvm/runtime/container.h index e02cd5360a6b..8f426415ffee 100644 --- a/include/tvm/runtime/container.h +++ b/include/tvm/runtime/container.h @@ -700,7 +700,15 @@ class Optional : public ObjectRef { } } auto operator!=(const Optional& other) const { - return !(*this == other); + // support case where sub-class returns a symbolic ref type. + using RetType = decltype(value() != other.value()); + if (same_as(other)) return RetType(false); + if (*this != nullptr && other != nullptr) { + return value() != other.value(); + } else { + // one of them is nullptr. + return RetType(true); + } } auto operator==(const T& other) const { using RetType = decltype(value() == other); @@ -719,7 +727,9 @@ class Optional : public ObjectRef { } template auto operator!=(const U& other) const { - return !(*this == other); + using RetType = decltype(value() != other); + if (*this == nullptr) return RetType(true); + return value() != other; } static constexpr bool _type_is_nullable = true; };