Skip to content

Commit

Permalink
[REFACTOR][FFI] Make more clear naming for C API Type codes.
Browse files Browse the repository at this point in the history
This PR introduces more clear naming prefix for C API type codes
to avoid conflict with other packages.

We also removed TVMArray and TVMType to directly use DLTensor and DLDataType.
  • Loading branch information
tqchen committed Jan 15, 2020
1 parent 49d3144 commit 09d4444
Show file tree
Hide file tree
Showing 81 changed files with 580 additions and 587 deletions.
2 changes: 1 addition & 1 deletion apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int additional_info = args[0];
*rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer);
CHECK_EQ(rv->type_code(), kTVMNDArrayHandle);

});

Expand Down
18 changes: 9 additions & 9 deletions golang/src/ndarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (parray Array) nativeCPtr() (retVal uintptr) {
}

func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) {
ret := C.TVMArrayCopyFromBytes((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr())),
ret := C.TVMArrayCopyFromBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())),
data,
C.ulong(datalen))
if ret != 0 {
Expand All @@ -65,7 +65,7 @@ func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error)
func (parray Array) CopyFrom(val interface{}) (err error) {
var data unsafe.Pointer
var datalen int
dtype := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype
dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype

switch val.(type) {
case []int8:
Expand Down Expand Up @@ -126,7 +126,7 @@ func (parray Array) CopyFrom(val interface{}) (err error) {
}

func (parray Array) nativeCopyTo (data unsafe.Pointer, datalen int) (err error){
ret := C.TVMArrayCopyToBytes((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr())),
ret := C.TVMArrayCopyToBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())),
unsafe.Pointer(data),
C.ulong(datalen))

Expand All @@ -149,7 +149,7 @@ func (parray Array) AsSlice() (retVal interface{}, err error) {
for ii := range shape {
size *= shape[ii]
}
dtype := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype
dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype

switch parray.GetDType() {
case "int8":
Expand Down Expand Up @@ -221,13 +221,13 @@ func (parray Array) AsSlice() (retVal interface{}, err error) {

// GetNdim returns the number of dimentions in Array
func (parray Array) GetNdim() (retVal int32) {
retVal = int32(((*C.TVMArray)(unsafe.Pointer(parray))).ndim)
retVal = int32(((*C.DLTensor)(unsafe.Pointer(parray))).ndim)
return
}

// GetShape returns the number of dimentions in Array
func (parray Array) GetShape() (retVal []int64) {
shapePtr := (*C.int64_t)(((*C.TVMArray)(unsafe.Pointer(parray))).shape)
shapePtr := (*C.int64_t)(((*C.DLTensor)(unsafe.Pointer(parray))).shape)
ndim := parray.GetNdim()

shapeSlice := (*[1<<31] int64)(unsafe.Pointer(shapePtr))[:ndim:ndim]
Expand All @@ -238,14 +238,14 @@ func (parray Array) GetShape() (retVal []int64) {

// GetDType returns the number of dimentions in Array
func (parray Array) GetDType() (retVal string) {
ret := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype
ret := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype
retVal, _ = dtypeFromTVMType(*(*pTVMType)(unsafe.Pointer(&ret)))
return
}

// GetCtx returns the number of dimentions in Array
func (parray Array) GetCtx() (retVal Context) {
ret := ((*C.TVMArray)(unsafe.Pointer(parray))).ctx
ret := ((*C.DLTensor)(unsafe.Pointer(parray))).ctx
retVal = *(*Context)(unsafe.Pointer(&ret))
return
}
Expand Down Expand Up @@ -342,6 +342,6 @@ func Empty(shape []int64, args ...interface{}) (parray *Array, err error) {
//
// `ret` indicates the status of this api execution.
func nativeTVMArrayFree(parray Array) (retVal int32) {
retVal = (int32)(C.TVMArrayFree((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr()))))
retVal = (int32)(C.TVMArrayFree((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr()))))
return
}
52 changes: 26 additions & 26 deletions golang/src/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,38 +33,38 @@ import (
"unsafe"
)

// KHandle is golang type code for TVM enum kHandle.
var KHandle = int32(C.kHandle)
// KNull is golang type code for TVM kNull.
var KNull = int32(C.kNull)
// KTVMType is golang type code for TVM kTVMType.
var KTVMType = int32(C.kTVMType)
// KHandle is golang type code for TVM enum kTVMOpaqueHandle.
var KHandle = int32(C.kTVMOpaqueHandle)
// KNull is golang type code for TVM kTVMNullptr.
var KNull = int32(C.kTVMNullptr)
// KTVMType is golang type code for TVM kTVMDataType.
var KTVMType = int32(C.kTVMDataType)
// KTVMContext is golang type code for TVM kTVMContext.
var KTVMContext = int32(C.kTVMContext)
// KArrayHandle is golang type code for TVM kArrayHandle.
var KArrayHandle = int32(C.kArrayHandle)
// KObjectHandle is golang type code for TVM kObjectHandle.
var KObjectHandle = int32(C.kObjectHandle)
// KModuleHandle is gonag type code for TVM kModuleHandle.
var KModuleHandle = int32(C.kModuleHandle)
// KFuncHandle is gonalg type code for TVM kFuncHandle.
var KFuncHandle = int32(C.kFuncHandle)
// KStr is golang type code for TVM kStr.
var KStr = int32(C.kStr)
// KBytes is golang type code for TVM kBytes.
var KBytes = int32(C.kBytes)
// KNDArrayContainer is golang typecode for kNDArrayContainer.
var KNDArrayContainer = int32(C.kNDArrayContainer)
// KExtBegin is golang enum corresponding to TVM kExtBegin.
var KExtBegin = int32(C.kExtBegin)
// KArrayHandle is golang type code for TVM kTVMDLTensorHandle.
var KArrayHandle = int32(C.kTVMDLTensorHandle)
// KObjectHandle is golang type code for TVM kTVMObjectHandle.
var KObjectHandle = int32(C.kTVMObjectHandle)
// KModuleHandle is gonag type code for TVM kTVMModuleHandle.
var KModuleHandle = int32(C.kTVMModuleHandle)
// KFuncHandle is gonalg type code for TVM kTVMPackedFuncHandle.
var KFuncHandle = int32(C.kTVMPackedFuncHandle)
// KStr is golang type code for TVM kTVMStr.
var KStr = int32(C.kTVMStr)
// KBytes is golang type code for TVM kTVMBytes.
var KBytes = int32(C.kTVMBytes)
// KNDArrayContainer is golang typecode for kTVMNDArrayHandle.
var KNDArrayContainer = int32(C.kTVMNDArrayHandle)
// KExtBegin is golang enum corresponding to TVM kTVMExtBegin.
var KExtBegin = int32(C.kTVMExtBegin)
// KNNVMFirst is golang enum corresponding to TVM kNNVMFirst.
var KNNVMFirst = int32(C.kNNVMFirst)
var KNNVMFirst = int32(C.kTVMNNVMFirst)
// KNNVMLast is golang enum corresponding to TVM kNNVMLast.
var KNNVMLast = int32(C.kNNVMLast)
var KNNVMLast = int32(C.kTVMNNVMLast)
// KExtReserveEnd is golang enum corresponding to TVM kExtReserveEnd.
var KExtReserveEnd = int32(C.kExtReserveEnd)
var KExtReserveEnd = int32(C.kTVMExtReserveEnd)
// KExtEnd is golang enum corresponding to TVM kExtEnd.
var KExtEnd = int32(C.kExtEnd)
var KExtEnd = int32(C.kTVMExtEnd)
// KDLInt is golang type code for TVM kDLInt.
var KDLInt = int32(C.kDLInt)
// KDLUInt is golang type code for TVM kDLUInt.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kTVMCustomBegin)) {
return FloatImm(t, static_cast<double>(value));
}
LOG(FATAL) << "cannot make const for type " << t;
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ inline TObjectRef NullValue() {

template<>
inline DataType NullValue<DataType>() {
return DataType(kHandle, 0, 0);
return DataType(DataType::kHandle, 0, 0);
}

/*! \brief Error thrown during attribute checking. */
Expand Down Expand Up @@ -492,7 +492,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {

template<>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) {
if (val.type_code() == kTVMStr) {
*ptr = val.operator std::string();
} else {
LOG(FATAL) << "Expect str";
Expand Down Expand Up @@ -762,7 +762,7 @@ class AttrsNode : public BaseAttrsNode {
// linear search.
auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
for (int i = 0; i < args.size(); i += 2) {
CHECK_EQ(args.type_codes[i], kStr);
CHECK_EQ(args.type_codes[i], kTVMStr);
if (!std::strcmp(key, args.values[i].v_str)) {
*val = args[i + 1];
return true;
Expand All @@ -777,7 +777,7 @@ class AttrsNode : public BaseAttrsNode {
// construct a map then do lookup.
std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
for (int i = 0; i < args.size(); i += 2) {
CHECK_EQ(args.type_codes[i], kStr);
CHECK_EQ(args.type_codes[i], kTVMStr);
kwargs[args[i].operator std::string()] = args[i + 1];
}
auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) {
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ struct ObjectTypeChecker<Map<K, V> > {

// extensions for tvm arg value
inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kNull) return PrimExpr();
if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
Expand All @@ -110,7 +110,7 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
return PrimExpr(static_cast<float>(value_.v_float64));
}

TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);

if (ptr->IsInstance<IterVarNode>()) {
Expand All @@ -126,13 +126,13 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
}

inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kTVMNullptr) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
Expand Down
58 changes: 20 additions & 38 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,62 +86,44 @@ typedef enum {
} TVMDeviceExtType;

/*!
* \brief The type code in TVMType
* \note TVMType is used in two places.
* \brief The type code in used in the TVM FFI.
*/
typedef enum {
// The type code of other types are compatible with DLPack.
// The next few fields are extension types
// that is used by TVM API calls.
kHandle = 3U,
kNull = 4U,
kTVMType = 5U,
kTVMOpaqueHandle = 3U,
kTVMNullptr = 4U,
kTVMDataType = 5U,
kTVMContext = 6U,
kArrayHandle = 7U,
kObjectHandle = 8U,
kModuleHandle = 9U,
kFuncHandle = 10U,
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kTVMDLTensorHandle = 7U,
kTVMObjectHandle = 8U,
kTVMModuleHandle = 9U,
kTVMPackedFuncHandle = 10U,
kTVMStr = 11U,
kTVMBytes = 12U,
kTVMNDArrayHandle = 13U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kExtBegin = 15U,
kNNVMFirst = 16U,
kNNVMLast = 20U,
kTVMExtBegin = 15U,
kTVMNNVMFirst = 16U,
kTVMNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U,
kTVMExtReserveEnd = 64U,
kTVMExtEnd = 128U,
// The rest of the space is used for custom, user-supplied datatypes
kCustomBegin = 129U,
kTVMCustomBegin = 129U,
} TVMTypeCode;

/*!
* \brief The data type used in TVM Runtime.
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*
* \note Arguments TVM API function always takes bits=64 and lanes=1
*/
typedef DLDataType TVMType;

/*!
* \brief The Device information, abstract away common device types.
*/
typedef DLContext TVMContext;

/*!
* \brief The tensor array structure to TVM API.
*/
typedef DLTensor TVMArray;

/*! \brief the array handle */
typedef TVMArray* TVMArrayHandle;
typedef DLTensor* TVMArrayHandle;

/*!
* \brief Union type of values
Expand All @@ -152,13 +134,13 @@ typedef union {
double v_float64;
void* v_handle;
const char* v_str;
TVMType v_type;
DLDataType v_type;
TVMContext v_ctx;
} TVMValue;

/*!
* \brief Byte array type used to pass in byte array
* When kBytes is used as data type.
* When kTVMBytes is used as data type.
*/
typedef struct {
const char* data;
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DataType {
kInt = kDLInt,
kUInt = kDLUInt,
kFloat = kDLFloat,
kHandle = TVMTypeCode::kHandle,
kHandle = TVMTypeCode::kTVMOpaqueHandle,
};
/*! \brief default constructor */
DataType() {}
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class TVM_DLL DeviceAPI {
virtual void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) = 0;
DLDataType type_hint) = 0;
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
Expand All @@ -115,7 +115,7 @@ class TVM_DLL DeviceAPI {
size_t num_bytes,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
DLDataType type_hint,
TVMStreamHandle stream) = 0;
/*!
* \brief Create a new stream of execution.
Expand Down Expand Up @@ -177,7 +177,7 @@ class TVM_DLL DeviceAPI {
*/
virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes,
TVMType type_hint = {});
DLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
Expand Down
Loading

0 comments on commit 09d4444

Please sign in to comment.