From 47203c3dff8d2d9f152b46477b61419efc6112ec Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:50:06 -0500 Subject: [PATCH 1/9] fix: check nullptr before calling `->` --- include/tvm/runtime/object.h | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index 0ed61177e65a..b1fdbd08deab 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -538,7 +538,7 @@ class ObjectRef { /*! \return the internal object pointer */ const Object* get() const { return data_.get(); } /*! \return the internal object pointer */ - const Object* operator->() const { return get(); } + const Object* operator->() const { ICHECK(nullptr != get()) << "Calling `->` to nullptr"; return get(); } /*! \return whether the reference is unique */ bool unique() const { return data_.unique(); } /*! \return The use count of the ptr, for debug purposes */ @@ -707,8 +707,12 @@ struct ObjectPtrEqual { TypeName() = default; \ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { return static_cast(data_.get()); } \ - const ObjectName* get() const { return operator->(); } \ + const ObjectName* operator->() const { \ + auto ptr = static_cast(data_.get()); \ + ICHECK(nullptr != ptr) << "Calling `->` to <" #ObjectName ">(null)"; \ + return ptr; \ + } \ + const ObjectName* get() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName; /* @@ -738,7 +742,11 @@ struct ObjectPtrEqual { TypeName() = default; \ TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - ObjectName* operator->() const { return static_cast(data_.get()); } \ + ObjectName* operator->() const { \ + auto ptr = static_cast(data_.get()); \ + ICHECK(nullptr != ptr) << "Calling `->` to <" #ObjectName ">(null)"; \ + return ptr; \ + } \ using ContainerType = ObjectName; /* From d27f8c9f324f63e450e80fd569d8fb4c4249d320 Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:50:48 -0500 Subject: [PATCH 2/9] fix: check nullptr before calling `->` in PrimExpr --- include/tvm/ir/expr.h | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/include/tvm/ir/expr.h b/include/tvm/ir/expr.h index b910d32ceca4..e6753cf8a3bd 100644 --- a/include/tvm/ir/expr.h +++ b/include/tvm/ir/expr.h @@ -119,11 +119,11 @@ class PrimExpr : public BaseExpr { */ TVM_DLL PrimExpr(float value); // NOLINT(*) - /*! \return the data type of this expression. */ - DataType dtype() const { return static_cast(get())->dtype; } - TVM_DEFINE_OBJECT_REF_METHODS(PrimExpr, BaseExpr, PrimExprNode); + /*! \return the data type of this expression. */ + DataType dtype() const { return operator->()->dtype; } + private: // Internal function for conversion. friend struct runtime::PackedFuncValueConverter; From cb9ebbbb7806c913dff5b8936d8adfc7677787e9 Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:51:01 -0500 Subject: [PATCH 3/9] fix: check nullptr before calling `->` in Var --- include/tvm/tir/var.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 65c5c12a701b..138c17cfd6d5 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -114,7 +114,8 @@ class Var : public PrimExpr { * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* operator->() const { return get(); } + const VarNode* operator->() const { ICHECK(nullptr != get()) << "Calling `->` to (null)"; return get(); } + /*! * \brief Get pointer to the internal value. * \return the corresponding Variable. From 9c6e120c06565b6cba89eff364190cf4e181214f Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:51:55 -0500 Subject: [PATCH 4/9] fix: allow range-for for empty null array --- include/tvm/runtime/container/array.h | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index 8830653da88c..f4bce8419fb2 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -363,10 +363,14 @@ class Array : public ObjectRef { using reverse_iterator = ReverseIterAdapter; /*! \return begin iterator */ - iterator begin() const { return iterator(GetArrayNode()->begin()); } + iterator begin() const { + return iterator(nullptr == GetArrayNode() ? static_cast(nullptr) : GetArrayNode()->begin()); + } /*! \return end iterator */ - iterator end() const { return iterator(GetArrayNode()->end()); } + iterator end() const { + return iterator(nullptr == GetArrayNode() ? static_cast(nullptr) : GetArrayNode()->end()); + } /*! \return rbegin iterator */ reverse_iterator rbegin() const { From 8852aa5b13d44e04eca87523cd0a04d7ec40aa89 Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:53:12 -0500 Subject: [PATCH 5/9] fix: check if statement is null to avoid twice exception --- src/tir/analysis/block_access_region_detector.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 90aaa35d60d8..dc48000b82a8 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -106,6 +106,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { + ICHECK(nullptr != stmt.get()) << "Cannot pass null statement to BlockReadWriteDetector::operator()"; const auto* block = stmt.as(); ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); for (const MatchBufferRegion& match_buffer : block->match_buffers) { From fb088f5715bcd233fdd3eaad2b790339519cf7f1 Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:53:39 -0500 Subject: [PATCH 6/9] fix: check map validity --- src/tir/analysis/verify_gpu_code.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index efffa9031ac0..86aa416fd261 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -242,6 +242,8 @@ class GPUCodeVerifier : public StmtExprVisitor { }; std::vector VerifyGPUCode_(const PrimFunc& func, Map constraints) { + ICHECK(nullptr != constraints.get()) << "Cannot pass null map to VerifyGPUCode_"; + GPUCodeVerifier verifier; int64_t max_local_memory_per_block = INT64_MAX; From 7c8779058f787d99253b0a7f0a85ba08f2e9a894 Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:54:06 -0500 Subject: [PATCH 7/9] add test to test tricky inputs in Python interface --- tests/python/unittest/test_tir_base.py | 68 ++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index 66f3ef9e599f..c838564298d5 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -117,9 +117,77 @@ def test_exception(): with pytest.raises(tvm.TVMError): x = tir.Var(name=1, dtype="int") +def test_nullptr_exception(): + test_candidates = [ + (tvm.tir.analysis.calculate_workspace_bytes, (None, None)), + (tvm.tir.analysis.detect_buffer_access_lca, (None, )), + (tvm.tir.analysis.get_block_access_region, (None, None)), + (tvm.tir.analysis.get_block_read_write_region, (None, None)), + (tvm.tir.analysis.verify_gpu_code, (None, None)), + (tvm.tir.analysis.verify_memory, (None, )), + (tvm.tir.analysis.verify_ssa, (None, )), + (tvm.tir.analysis.BufferRegion, (None, None)), + + # tir.expr + (tvm.tir.expr.BufferLoad, (None, None, None)), + (tvm.tir.expr.Call, (None, None, None, None)), + (tvm.tir.expr.ProducerLoad, (None, None, None)), + + # tir.generic + (tvm.tir.generic.add, (None, None, None)), + (tvm.tir.generic.cast, (None, None, None)), + (tvm.tir.generic.divide, (None, None, None)), + (tvm.tir.generic.floordiv, (None, None, None)), + (tvm.tir.generic.multiply, (None, None, None)), + (tvm.tir.generic.subtract, (None, None, None)), + + # tir.op + (tvm.tir.op.abs, (None, None)), + (tvm.tir.op.ceil, (None, None)), + (tvm.tir.op.clz, (None, )), + (tvm.tir.op.div, (None, None, None)), + (tvm.tir.op.floor, (None, None)), + (tvm.tir.op.floordiv, (None, None, None)), + (tvm.tir.op.floormod, (None, None, None)), + (tvm.tir.op.if_then_else, (None, None, None, None)), + (tvm.tir.op.indexdiv, (None, None, None)), + (tvm.tir.op.indexmod, (None, None, None)), + (tvm.tir.op.isfinite, (None, None)), + (tvm.tir.op.isinf, (None, None)), + (tvm.tir.op.isnan, (None, None)), + (tvm.tir.op.nearbyint, (None, None)), + (tvm.tir.op.power, (None, None, None)), + (tvm.tir.op.q_multiply_shift, (None, None, None, None)), + (tvm.tir.op.round, (None, None)), + (tvm.tir.op.trunc, (None, None)), + (tvm.tir.op.truncdiv, (None, None, None)), + (tvm.tir.op.truncmod, (None, None, None)), + (tvm.tir.op.Call, (None, None, None, None)), + + # tir.stmt_functor + (tvm.tir.stmt_functor.ir_transform, (None, None, None, None)), + (tvm.tir.stmt_functor.post_order_visit, (None, None)), + (tvm.tir.stmt_functor.substitute, (None, None)), + + # tvm.stmt + (tvm.tir.stmt.Allocate, (None, None, None, None, None, None)), + (tvm.tir.stmt.BlockRealize, (None, None, None, None)), + (tvm.tir.stmt.BufferRegion, (None, None)), + (tvm.tir.stmt.MatchBufferRegion, (None, None)), + ] + + for func, args in test_candidates: + try: + print(func.__name__) + func(*args) + except TVMError as _: + pass + + if __name__ == "__main__": test_scalar_add() test_ret_const() test_control_flow_jump() test_exception() + test_nullptr_exception() From 54d05e09927edad054845309dbd69f0eca5e325c Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:54:42 -0500 Subject: [PATCH 8/9] refact: cpp format --- include/tvm/runtime/container/array.h | 10 +++++--- include/tvm/runtime/object.h | 25 +++++++++++-------- include/tvm/tir/var.h | 5 +++- .../analysis/block_access_region_detector.cc | 3 ++- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/include/tvm/runtime/container/array.h b/include/tvm/runtime/container/array.h index f4bce8419fb2..c6d2f5f0b998 100644 --- a/include/tvm/runtime/container/array.h +++ b/include/tvm/runtime/container/array.h @@ -363,13 +363,15 @@ class Array : public ObjectRef { using reverse_iterator = ReverseIterAdapter; /*! \return begin iterator */ - iterator begin() const { - return iterator(nullptr == GetArrayNode() ? static_cast(nullptr) : GetArrayNode()->begin()); + iterator begin() const { + return iterator(nullptr == GetArrayNode() ? static_cast(nullptr) + : GetArrayNode()->begin()); } /*! \return end iterator */ - iterator end() const { - return iterator(nullptr == GetArrayNode() ? static_cast(nullptr) : GetArrayNode()->end()); + iterator end() const { + return iterator(nullptr == GetArrayNode() ? static_cast(nullptr) + : GetArrayNode()->end()); } /*! \return rbegin iterator */ diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index b1fdbd08deab..72500210cfe1 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -538,7 +538,10 @@ class ObjectRef { /*! \return the internal object pointer */ const Object* get() const { return data_.get(); } /*! \return the internal object pointer */ - const Object* operator->() const { ICHECK(nullptr != get()) << "Calling `->` to nullptr"; return get(); } + const Object* operator->() const { + ICHECK(nullptr != get()) << "Calling `->` to nullptr"; + return get(); + } /*! \return whether the reference is unique */ bool unique() const { return data_.unique(); } /*! \return The use count of the ptr, for debug purposes */ @@ -703,16 +706,16 @@ struct ObjectPtrEqual { * \param ParentType The parent type of the objectref * \param ObjectName The type name of the object. */ -#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ - TypeName() = default; \ - explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ - TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ - const ObjectName* operator->() const { \ - auto ptr = static_cast(data_.get()); \ +#define TVM_DEFINE_OBJECT_REF_METHODS(TypeName, ParentType, ObjectName) \ + TypeName() = default; \ + explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ + TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(TypeName); \ + const ObjectName* operator->() const { \ + auto ptr = static_cast(data_.get()); \ ICHECK(nullptr != ptr) << "Calling `->` to <" #ObjectName ">(null)"; \ - return ptr; \ - } \ - const ObjectName* get() const { return static_cast(data_.get()); } \ + return ptr; \ + } \ + const ObjectName* get() const { return static_cast(data_.get()); } \ using ContainerType = ObjectName; /* @@ -744,7 +747,7 @@ struct ObjectPtrEqual { explicit TypeName(::tvm::runtime::ObjectPtr<::tvm::runtime::Object> n) : ParentType(n) {} \ ObjectName* operator->() const { \ auto ptr = static_cast(data_.get()); \ - ICHECK(nullptr != ptr) << "Calling `->` to <" #ObjectName ">(null)"; \ + ICHECK(nullptr != ptr) << "Calling `->` to <" #ObjectName ">(null)"; \ return ptr; \ } \ using ContainerType = ObjectName; diff --git a/include/tvm/tir/var.h b/include/tvm/tir/var.h index 138c17cfd6d5..3e74783da06e 100644 --- a/include/tvm/tir/var.h +++ b/include/tvm/tir/var.h @@ -114,7 +114,10 @@ class Var : public PrimExpr { * \brief Get pointer to the internal value. * \return the corresponding Variable. */ - const VarNode* operator->() const { ICHECK(nullptr != get()) << "Calling `->` to (null)"; return get(); } + const VarNode* operator->() const { + ICHECK(nullptr != get()) << "Calling `->` to (null)"; + return get(); + } /*! * \brief Get pointer to the internal value. diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index dc48000b82a8..2c69a042a563 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -106,7 +106,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { }; void BlockReadWriteDetector::operator()(const Stmt& stmt) { - ICHECK(nullptr != stmt.get()) << "Cannot pass null statement to BlockReadWriteDetector::operator()"; + ICHECK(nullptr != stmt.get()) + << "Cannot pass null statement to BlockReadWriteDetector::operator()"; const auto* block = stmt.as(); ICHECK(block != nullptr) << "Only visiting Blocks is allowed, but got " << stmt->GetTypeKey(); for (const MatchBufferRegion& match_buffer : block->match_buffers) { From 63425b39f501d2f71d6cb3749c1158b5372d7e60 Mon Sep 17 00:00:00 2001 From: ganler Date: Fri, 10 Sep 2021 23:58:02 -0500 Subject: [PATCH 9/9] refact: python format --- tests/python/unittest/test_tir_base.py | 107 ++++++++++++------------- 1 file changed, 51 insertions(+), 56 deletions(-) diff --git a/tests/python/unittest/test_tir_base.py b/tests/python/unittest/test_tir_base.py index c838564298d5..6c9a135d8ac0 100644 --- a/tests/python/unittest/test_tir_base.py +++ b/tests/python/unittest/test_tir_base.py @@ -117,63 +117,59 @@ def test_exception(): with pytest.raises(tvm.TVMError): x = tir.Var(name=1, dtype="int") + def test_nullptr_exception(): test_candidates = [ - (tvm.tir.analysis.calculate_workspace_bytes, (None, None)), - (tvm.tir.analysis.detect_buffer_access_lca, (None, )), - (tvm.tir.analysis.get_block_access_region, (None, None)), - (tvm.tir.analysis.get_block_read_write_region, (None, None)), - (tvm.tir.analysis.verify_gpu_code, (None, None)), - (tvm.tir.analysis.verify_memory, (None, )), - (tvm.tir.analysis.verify_ssa, (None, )), - (tvm.tir.analysis.BufferRegion, (None, None)), - - # tir.expr - (tvm.tir.expr.BufferLoad, (None, None, None)), - (tvm.tir.expr.Call, (None, None, None, None)), - (tvm.tir.expr.ProducerLoad, (None, None, None)), - - # tir.generic - (tvm.tir.generic.add, (None, None, None)), - (tvm.tir.generic.cast, (None, None, None)), - (tvm.tir.generic.divide, (None, None, None)), - (tvm.tir.generic.floordiv, (None, None, None)), - (tvm.tir.generic.multiply, (None, None, None)), - (tvm.tir.generic.subtract, (None, None, None)), - - # tir.op - (tvm.tir.op.abs, (None, None)), - (tvm.tir.op.ceil, (None, None)), - (tvm.tir.op.clz, (None, )), - (tvm.tir.op.div, (None, None, None)), - (tvm.tir.op.floor, (None, None)), - (tvm.tir.op.floordiv, (None, None, None)), - (tvm.tir.op.floormod, (None, None, None)), - (tvm.tir.op.if_then_else, (None, None, None, None)), - (tvm.tir.op.indexdiv, (None, None, None)), - (tvm.tir.op.indexmod, (None, None, None)), - (tvm.tir.op.isfinite, (None, None)), - (tvm.tir.op.isinf, (None, None)), - (tvm.tir.op.isnan, (None, None)), - (tvm.tir.op.nearbyint, (None, None)), - (tvm.tir.op.power, (None, None, None)), - (tvm.tir.op.q_multiply_shift, (None, None, None, None)), - (tvm.tir.op.round, (None, None)), - (tvm.tir.op.trunc, (None, None)), - (tvm.tir.op.truncdiv, (None, None, None)), - (tvm.tir.op.truncmod, (None, None, None)), - (tvm.tir.op.Call, (None, None, None, None)), - - # tir.stmt_functor - (tvm.tir.stmt_functor.ir_transform, (None, None, None, None)), - (tvm.tir.stmt_functor.post_order_visit, (None, None)), - (tvm.tir.stmt_functor.substitute, (None, None)), - - # tvm.stmt - (tvm.tir.stmt.Allocate, (None, None, None, None, None, None)), - (tvm.tir.stmt.BlockRealize, (None, None, None, None)), - (tvm.tir.stmt.BufferRegion, (None, None)), - (tvm.tir.stmt.MatchBufferRegion, (None, None)), + (tvm.tir.analysis.calculate_workspace_bytes, (None, None)), + (tvm.tir.analysis.detect_buffer_access_lca, (None,)), + (tvm.tir.analysis.get_block_access_region, (None, None)), + (tvm.tir.analysis.get_block_read_write_region, (None, None)), + (tvm.tir.analysis.verify_gpu_code, (None, None)), + (tvm.tir.analysis.verify_memory, (None,)), + (tvm.tir.analysis.verify_ssa, (None,)), + (tvm.tir.analysis.BufferRegion, (None, None)), + # tir.expr + (tvm.tir.expr.BufferLoad, (None, None, None)), + (tvm.tir.expr.Call, (None, None, None, None)), + (tvm.tir.expr.ProducerLoad, (None, None, None)), + # tir.generic + (tvm.tir.generic.add, (None, None, None)), + (tvm.tir.generic.cast, (None, None, None)), + (tvm.tir.generic.divide, (None, None, None)), + (tvm.tir.generic.floordiv, (None, None, None)), + (tvm.tir.generic.multiply, (None, None, None)), + (tvm.tir.generic.subtract, (None, None, None)), + # tir.op + (tvm.tir.op.abs, (None, None)), + (tvm.tir.op.ceil, (None, None)), + (tvm.tir.op.clz, (None,)), + (tvm.tir.op.div, (None, None, None)), + (tvm.tir.op.floor, (None, None)), + (tvm.tir.op.floordiv, (None, None, None)), + (tvm.tir.op.floormod, (None, None, None)), + (tvm.tir.op.if_then_else, (None, None, None, None)), + (tvm.tir.op.indexdiv, (None, None, None)), + (tvm.tir.op.indexmod, (None, None, None)), + (tvm.tir.op.isfinite, (None, None)), + (tvm.tir.op.isinf, (None, None)), + (tvm.tir.op.isnan, (None, None)), + (tvm.tir.op.nearbyint, (None, None)), + (tvm.tir.op.power, (None, None, None)), + (tvm.tir.op.q_multiply_shift, (None, None, None, None)), + (tvm.tir.op.round, (None, None)), + (tvm.tir.op.trunc, (None, None)), + (tvm.tir.op.truncdiv, (None, None, None)), + (tvm.tir.op.truncmod, (None, None, None)), + (tvm.tir.op.Call, (None, None, None, None)), + # tir.stmt_functor + (tvm.tir.stmt_functor.ir_transform, (None, None, None, None)), + (tvm.tir.stmt_functor.post_order_visit, (None, None)), + (tvm.tir.stmt_functor.substitute, (None, None)), + # tvm.stmt + (tvm.tir.stmt.Allocate, (None, None, None, None, None, None)), + (tvm.tir.stmt.BlockRealize, (None, None, None, None)), + (tvm.tir.stmt.BufferRegion, (None, None)), + (tvm.tir.stmt.MatchBufferRegion, (None, None)), ] for func, args in test_candidates: @@ -182,7 +178,6 @@ def test_nullptr_exception(): func(*args) except TVMError as _: pass - if __name__ == "__main__":