diff --git a/include/tvm/runtime/ndarray.h b/include/tvm/runtime/ndarray.h index 0fc8e42b8bcb..e2a447e4235c 100644 --- a/include/tvm/runtime/ndarray.h +++ b/include/tvm/runtime/ndarray.h @@ -21,7 +21,7 @@ namespace runtime { class NDArray { public: // internal container type - struct Container; + class Container; /*! \brief default constructor */ NDArray() {} /*! @@ -173,7 +173,7 @@ class NDArray { // internal namespace struct Internal; - private: + protected: /*! \brief Internal Data content */ Container* data_{nullptr}; // enable internal functions @@ -198,7 +198,7 @@ inline bool SaveDLTensor(dmlc::Stream* strm, const DLTensor* tensor); * * \note: do not use this function directly, use NDArray. */ -struct NDArray::Container { +class NDArray::Container { public: // NOTE: the first part of this structure is the same as // DLManagedTensor, note that, however, the deleter @@ -225,6 +225,28 @@ struct NDArray::Container { * currently defined by the system. */ void (*deleter)(Container* self) = nullptr; + + protected: + friend class NDArray; + friend class RPCWrappedFunc; + /*! + * \brief Type flag used to indicate subclass. + * Default value 0 means normal NDArray::Conatainer. + * + * We can extend a more specialized NDArray::Container + * and use the array_type_index_ to indicate + * the specific array subclass. + */ + uint32_t array_type_index_{0}; + /*! \brief The internal reference counter */ + std::atomic ref_counter_{0}; + /*! + * \brief The shape container, + * can be used used for shape data. + */ + std::vector shape_; + + public: /*! \brief default constructor */ Container() { dl_tensor.data = nullptr; @@ -246,17 +268,6 @@ struct NDArray::Container { } } } - - private: - friend class NDArray; - friend class RPCWrappedFunc; - /*! - * \brief The shape container, - * can be used used for shape data. - */ - std::vector shape_; - /*! \brief The internal array object */ - std::atomic ref_counter_{0}; }; // implementations of inline functions