diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b910d32ceca4..e6753cf8a3bd 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -119,11 +119,11 @@ class PrimExpr : public BaseExpr { */ TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! \return the data type of this expression. */ - DataType dtype() const { return static_cast(get())->dtype; } - TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); + /*! \return the data type of this expression. */ + DataType dtype() const { return operator->()->dtype; } + private: // Internal function for conversion. friend struct runtime::PackedFuncValueConverter; diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 8830653da88c..c6d2f5f0b998 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -363,10 +363,16 @@ class Array : public ObjectRef { using reverse_iterator = ReverseIterAdapter; /*! \return begin iterator */ - iterator begin() const { return iterator(GetArrayNode()->begin()); } + iterator begin() const { + return iterator(nullptr == GetArrayNode() ? static_cast(nullptr) + : GetArrayNode()->begin()); + } /*! \return end iterator */ - iterator end() const { return iterator(GetArrayNode()->end()); } + iterator end() const { + return iterator(nullptr == GetArrayNode() ? static_cast(nullptr) + : GetArrayNode()->end()); + } /*! \return rbegin iterator */ reverse_iterator rbegin() const { diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0ed61177e65a..72500210cfe1 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -538,7 +538,10 @@ class ObjectRef { /*! \return the internal object pointer */ const Object* get() const { return data_.get(); } /*! \return the internal object pointer */ - const Object* operator->() const { return get(); } + const Object* operator->() const { + ICHECK(nullptr != get()) << "Calling `->` to nullptr"; + return get(); + } /*! \return whether the reference is unique */ bool unique() const { return data_.unique(); } /*! \return The use count of the ptr, for debug purposes */ @@ -703,12 +706,16 @@ struct ObjectPtrEqual { * \param ParentType The parent type of the objectref * \param ObjectName The type name of the object. */ -#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { return static_cast(data_.get()); } \ - const ObjectName* get() const { return operator->(); } \ +#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { \ + auto ptr = static_cast(data_.get()); \ + ICHECK(nullptr != ptr) << "Calling `->` to <" #ObjectName ">(null)"; \ + return ptr; \ + } \ + const ObjectName* get() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName; /* @@ -738,7 +745,11 @@ struct ObjectPtrEqual { TypeName() = default; \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - ObjectName* operator->() const { return static_cast(data_.get()); } \ + ObjectName* operator->() const { \ + auto ptr = static_cast(data_.get()); \ + ICHECK(nullptr != ptr) << "Calling `->` to <" #ObjectName ">(null)"; \ + return ptr; \ + } \ using ContainerType = ObjectName; /* diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 65c5c12a701b..3e74783da06e 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -114,7 +114,11 @@ class Var : public PrimExpr { * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* operator->() const { return get(); } + const VarNode* operator->() const { + ICHECK(nullptr != get()) << "Calling `->` to (null)"; + return get(); + } + /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 90aaa35d60d8..2c69a042a563 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -106,6 +106,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { + ICHECK(nullptr != stmt.get()) + << "Cannot pass null statement to BlockReadWriteDetector::operator()"; const auto* block = stmt.as(); ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); for (const MatchBufferRegion& match_buffer : block->match_buffers) { diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index efffa9031ac0..86aa416fd261 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -242,6 +242,8 @@ class GPUCodeVerifier : public StmtExprVisitor { }; std::vector VerifyGPUCode_(const PrimFunc& func, Map constraints) { + ICHECK(nullptr != constraints.get()) << "Cannot pass null map to VerifyGPUCode_"; + GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index 66f3ef9e599f..6c9a135d8ac0 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -118,8 +118,71 @@ def test_exception(): x = tir.Var(name=1, dtype="int") +def test_nullptr_exception(): + test_candidates = [ + (tvm.tir.analysis.calculate_workspace_bytes, (None, None)), + (tvm.tir.analysis.detect_buffer_access_lca, (None,)), + (tvm.tir.analysis.get_block_access_region, (None, None)), + (tvm.tir.analysis.get_block_read_write_region, (None, None)), + (tvm.tir.analysis.verify_gpu_code, (None, None)), + (tvm.tir.analysis.verify_memory, (None,)), + (tvm.tir.analysis.verify_ssa, (None,)), + (tvm.tir.analysis.BufferRegion, (None, None)), + # tir.expr + (tvm.tir.expr.BufferLoad, (None, None, None)), + (tvm.tir.expr.Call, (None, None, None, None)), + (tvm.tir.expr.ProducerLoad, (None, None, None)), + # tir.generic + (tvm.tir.generic.add, (None, None, None)), + (tvm.tir.generic.cast, (None, None, None)), + (tvm.tir.generic.divide, (None, None, None)), + (tvm.tir.generic.floordiv, (None, None, None)), + (tvm.tir.generic.multiply, (None, None, None)), + (tvm.tir.generic.subtract, (None, None, None)), + # tir.op + (tvm.tir.op.abs, (None, None)), + (tvm.tir.op.ceil, (None, None)), + (tvm.tir.op.clz, (None,)), + (tvm.tir.op.div, (None, None, None)), + (tvm.tir.op.floor, (None, None)), + (tvm.tir.op.floordiv, (None, None, None)), + (tvm.tir.op.floormod, (None, None, None)), + (tvm.tir.op.if_then_else, (None, None, None, None)), + (tvm.tir.op.indexdiv, (None, None, None)), + (tvm.tir.op.indexmod, (None, None, None)), + (tvm.tir.op.isfinite, (None, None)), + (tvm.tir.op.isinf, (None, None)), + (tvm.tir.op.isnan, (None, None)), + (tvm.tir.op.nearbyint, (None, None)), + (tvm.tir.op.power, (None, None, None)), + (tvm.tir.op.q_multiply_shift, (None, None, None, None)), + (tvm.tir.op.round, (None, None)), + (tvm.tir.op.trunc, (None, None)), + (tvm.tir.op.truncdiv, (None, None, None)), + (tvm.tir.op.truncmod, (None, None, None)), + (tvm.tir.op.Call, (None, None, None, None)), + # tir.stmt_functor + (tvm.tir.stmt_functor.ir_transform, (None, None, None, None)), + (tvm.tir.stmt_functor.post_order_visit, (None, None)), + (tvm.tir.stmt_functor.substitute, (None, None)), + # tvm.stmt + (tvm.tir.stmt.Allocate, (None, None, None, None, None, None)), + (tvm.tir.stmt.BlockRealize, (None, None, None, None)), + (tvm.tir.stmt.BufferRegion, (None, None)), + (tvm.tir.stmt.MatchBufferRegion, (None, None)), + ] + + for func, args in test_candidates: + try: + print(func.__name__) + func(*args) + except TVMError as _: + pass + + if __name__ == "__main__": test_scalar_add() test_ret_const() test_control_flow_jump() test_exception() + test_nullptr_exception()