Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR][FFI] Make more clear naming for C API Type codes. #4715

Merged
merged 1 commit into from
Jan 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion apps/extension/src/tvm_ext.cc
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ TVM_REGISTER_GLOBAL("tvm_ext.nd_create")
.set_body([](TVMArgs args, TVMRetValue *rv) {
int additional_info = args[0];
*rv = NDSubClass(additional_info);
CHECK_EQ(rv->type_code(), kNDArrayContainer);
CHECK_EQ(rv->type_code(), kTVMNDArrayHandle);

});

Expand Down
18 changes: 9 additions & 9 deletions golang/src/ndarray.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func (parray Array) nativeCPtr() (retVal uintptr) {
}

func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error) {
ret := C.TVMArrayCopyFromBytes((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr())),
ret := C.TVMArrayCopyFromBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())),
data,
C.ulong(datalen))
if ret != 0 {
Expand All @@ -65,7 +65,7 @@ func (parray Array) nativeCopyFrom(data unsafe.Pointer, datalen int) (err error)
func (parray Array) CopyFrom(val interface{}) (err error) {
var data unsafe.Pointer
var datalen int
dtype := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype
dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype

switch val.(type) {
case []int8:
Expand Down Expand Up @@ -126,7 +126,7 @@ func (parray Array) CopyFrom(val interface{}) (err error) {
}

func (parray Array) nativeCopyTo (data unsafe.Pointer, datalen int) (err error){
ret := C.TVMArrayCopyToBytes((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr())),
ret := C.TVMArrayCopyToBytes((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr())),
unsafe.Pointer(data),
C.ulong(datalen))

Expand All @@ -149,7 +149,7 @@ func (parray Array) AsSlice() (retVal interface{}, err error) {
for ii := range shape {
size *= shape[ii]
}
dtype := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype
dtype := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype

switch parray.GetDType() {
case "int8":
Expand Down Expand Up @@ -221,13 +221,13 @@ func (parray Array) AsSlice() (retVal interface{}, err error) {

// GetNdim returns the number of dimentions in Array
func (parray Array) GetNdim() (retVal int32) {
retVal = int32(((*C.TVMArray)(unsafe.Pointer(parray))).ndim)
retVal = int32(((*C.DLTensor)(unsafe.Pointer(parray))).ndim)
return
}

// GetShape returns the number of dimentions in Array
func (parray Array) GetShape() (retVal []int64) {
shapePtr := (*C.int64_t)(((*C.TVMArray)(unsafe.Pointer(parray))).shape)
shapePtr := (*C.int64_t)(((*C.DLTensor)(unsafe.Pointer(parray))).shape)
ndim := parray.GetNdim()

shapeSlice := (*[1<<31] int64)(unsafe.Pointer(shapePtr))[:ndim:ndim]
Expand All @@ -238,14 +238,14 @@ func (parray Array) GetShape() (retVal []int64) {

// GetDType returns the number of dimentions in Array
func (parray Array) GetDType() (retVal string) {
ret := ((*C.TVMArray)(unsafe.Pointer(parray))).dtype
ret := ((*C.DLTensor)(unsafe.Pointer(parray))).dtype
retVal, _ = dtypeFromTVMType(*(*pTVMType)(unsafe.Pointer(&ret)))
return
}

// GetCtx returns the number of dimentions in Array
func (parray Array) GetCtx() (retVal Context) {
ret := ((*C.TVMArray)(unsafe.Pointer(parray))).ctx
ret := ((*C.DLTensor)(unsafe.Pointer(parray))).ctx
retVal = *(*Context)(unsafe.Pointer(&ret))
return
}
Expand Down Expand Up @@ -342,6 +342,6 @@ func Empty(shape []int64, args ...interface{}) (parray *Array, err error) {
//
// `ret` indicates the status of this api execution.
func nativeTVMArrayFree(parray Array) (retVal int32) {
retVal = (int32)(C.TVMArrayFree((*C.TVMArray)(unsafe.Pointer(parray.nativeCPtr()))))
retVal = (int32)(C.TVMArrayFree((*C.DLTensor)(unsafe.Pointer(parray.nativeCPtr()))))
return
}
52 changes: 26 additions & 26 deletions golang/src/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,38 +33,38 @@ import (
"unsafe"
)

// KHandle is golang type code for TVM enum kHandle.
var KHandle = int32(C.kHandle)
// KNull is golang type code for TVM kNull.
var KNull = int32(C.kNull)
// KTVMType is golang type code for TVM kTVMType.
var KTVMType = int32(C.kTVMType)
// KHandle is golang type code for TVM enum kTVMOpaqueHandle.
var KHandle = int32(C.kTVMOpaqueHandle)
// KNull is golang type code for TVM kTVMNullptr.
var KNull = int32(C.kTVMNullptr)
// KTVMType is golang type code for TVM kTVMDataType.
var KTVMType = int32(C.kTVMDataType)
// KTVMContext is golang type code for TVM kTVMContext.
var KTVMContext = int32(C.kTVMContext)
// KArrayHandle is golang type code for TVM kArrayHandle.
var KArrayHandle = int32(C.kArrayHandle)
// KObjectHandle is golang type code for TVM kObjectHandle.
var KObjectHandle = int32(C.kObjectHandle)
// KModuleHandle is gonag type code for TVM kModuleHandle.
var KModuleHandle = int32(C.kModuleHandle)
// KFuncHandle is gonalg type code for TVM kFuncHandle.
var KFuncHandle = int32(C.kFuncHandle)
// KStr is golang type code for TVM kStr.
var KStr = int32(C.kStr)
// KBytes is golang type code for TVM kBytes.
var KBytes = int32(C.kBytes)
// KNDArrayContainer is golang typecode for kNDArrayContainer.
var KNDArrayContainer = int32(C.kNDArrayContainer)
// KExtBegin is golang enum corresponding to TVM kExtBegin.
var KExtBegin = int32(C.kExtBegin)
// KArrayHandle is golang type code for TVM kTVMDLTensorHandle.
var KArrayHandle = int32(C.kTVMDLTensorHandle)
// KObjectHandle is golang type code for TVM kTVMObjectHandle.
var KObjectHandle = int32(C.kTVMObjectHandle)
// KModuleHandle is gonag type code for TVM kTVMModuleHandle.
var KModuleHandle = int32(C.kTVMModuleHandle)
// KFuncHandle is gonalg type code for TVM kTVMPackedFuncHandle.
var KFuncHandle = int32(C.kTVMPackedFuncHandle)
// KStr is golang type code for TVM kTVMStr.
var KStr = int32(C.kTVMStr)
// KBytes is golang type code for TVM kTVMBytes.
var KBytes = int32(C.kTVMBytes)
// KNDArrayContainer is golang typecode for kTVMNDArrayHandle.
var KNDArrayContainer = int32(C.kTVMNDArrayHandle)
// KExtBegin is golang enum corresponding to TVM kTVMExtBegin.
var KExtBegin = int32(C.kTVMExtBegin)
// KNNVMFirst is golang enum corresponding to TVM kNNVMFirst.
var KNNVMFirst = int32(C.kNNVMFirst)
var KNNVMFirst = int32(C.kTVMNNVMFirst)
// KNNVMLast is golang enum corresponding to TVM kNNVMLast.
var KNNVMLast = int32(C.kNNVMLast)
var KNNVMLast = int32(C.kTVMNNVMLast)
// KExtReserveEnd is golang enum corresponding to TVM kExtReserveEnd.
var KExtReserveEnd = int32(C.kExtReserveEnd)
var KExtReserveEnd = int32(C.kTVMExtReserveEnd)
// KExtEnd is golang enum corresponding to TVM kExtEnd.
var KExtEnd = int32(C.kExtEnd)
var KExtEnd = int32(C.kTVMExtEnd)
// KDLInt is golang type code for TVM kDLInt.
var KDLInt = int32(C.kDLInt)
// KDLUInt is golang type code for TVM kDLUInt.
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/expr_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -682,7 +682,7 @@ inline PrimExpr MakeConstScalar(DataType t, ValueType value) {
// datatypes lowering pass, we will lower the value to its true representation in the format
// specified by the datatype.
// TODO(gus) when do we need to start worrying about doubles not being precise enough?
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kCustomBegin)) {
if (static_cast<uint8_t>(t.code()) >= static_cast<uint8_t>(kTVMCustomBegin)) {
return FloatImm(t, static_cast<double>(value));
}
LOG(FATAL) << "cannot make const for type " << t;
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ inline TObjectRef NullValue() {

template<>
inline DataType NullValue<DataType>() {
return DataType(kHandle, 0, 0);
return DataType(DataType::kHandle, 0, 0);
}

/*! \brief Error thrown during attribute checking. */
Expand Down Expand Up @@ -492,7 +492,7 @@ inline void SetIntValue(T* ptr, const TVMArgValue& val) {

template<>
inline void SetValue<std::string>(std::string* ptr, const TVMArgValue& val) {
if (val.type_code() == kStr) {
if (val.type_code() == kTVMStr) {
*ptr = val.operator std::string();
} else {
LOG(FATAL) << "Expect str";
Expand Down Expand Up @@ -762,7 +762,7 @@ class AttrsNode : public BaseAttrsNode {
// linear search.
auto ffind = [&args](const char* key, runtime::TVMArgValue* val) {
for (int i = 0; i < args.size(); i += 2) {
CHECK_EQ(args.type_codes[i], kStr);
CHECK_EQ(args.type_codes[i], kTVMStr);
if (!std::strcmp(key, args.values[i].v_str)) {
*val = args[i + 1];
return true;
Expand All @@ -777,7 +777,7 @@ class AttrsNode : public BaseAttrsNode {
// construct a map then do lookup.
std::unordered_map<std::string, runtime::TVMArgValue> kwargs;
for (int i = 0; i < args.size(); i += 2) {
CHECK_EQ(args.type_codes[i], kStr);
CHECK_EQ(args.type_codes[i], kTVMStr);
kwargs[args[i].operator std::string()] = args[i + 1];
}
auto ffind = [&kwargs](const char *key, runtime::TVMArgValue* val) {
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ struct ObjectTypeChecker<Map<K, V> > {

// extensions for tvm arg value
inline TVMPODValue_::operator tvm::PrimExpr() const {
if (type_code_ == kNull) return PrimExpr();
if (type_code_ == kTVMNullptr) return PrimExpr();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
Expand All @@ -110,7 +110,7 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
return PrimExpr(static_cast<float>(value_.v_float64));
}

TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);

if (ptr->IsInstance<IterVarNode>()) {
Expand All @@ -126,13 +126,13 @@ inline TVMPODValue_::operator tvm::PrimExpr() const {
}

inline TVMPODValue_::operator tvm::Integer() const {
if (type_code_ == kNull) return Integer();
if (type_code_ == kTVMNullptr) return Integer();
if (type_code_ == kDLInt) {
CHECK_LE(value_.v_int64, std::numeric_limits<int>::max());
CHECK_GE(value_.v_int64, std::numeric_limits<int>::min());
return Integer(static_cast<int>(value_.v_int64));
}
TVM_CHECK_TYPE_CODE(type_code_, kObjectHandle);
TVM_CHECK_TYPE_CODE(type_code_, kTVMObjectHandle);
Object* ptr = static_cast<Object*>(value_.v_handle);
CHECK(ObjectTypeChecker<Integer>::Check(ptr))
<< "Expect type " << ObjectTypeChecker<PrimExpr>::TypeName()
Expand Down
58 changes: 20 additions & 38 deletions include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,62 +86,44 @@ typedef enum {
} TVMDeviceExtType;

/*!
* \brief The type code in TVMType
* \note TVMType is used in two places.
* \brief The type code in used in the TVM FFI.
*/
typedef enum {
// The type code of other types are compatible with DLPack.
// The next few fields are extension types
// that is used by TVM API calls.
kHandle = 3U,
kNull = 4U,
kTVMType = 5U,
kTVMOpaqueHandle = 3U,
kTVMNullptr = 4U,
kTVMDataType = 5U,
kTVMContext = 6U,
kArrayHandle = 7U,
kObjectHandle = 8U,
kModuleHandle = 9U,
kFuncHandle = 10U,
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kTVMDLTensorHandle = 7U,
kTVMObjectHandle = 8U,
kTVMModuleHandle = 9U,
kTVMPackedFuncHandle = 10U,
kTVMStr = 11U,
kTVMBytes = 12U,
kTVMNDArrayHandle = 13U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kExtBegin = 15U,
kNNVMFirst = 16U,
kNNVMLast = 20U,
kTVMExtBegin = 15U,
kTVMNNVMFirst = 16U,
kTVMNNVMLast = 20U,
// The following section of code is used for non-reserved types.
kExtReserveEnd = 64U,
kExtEnd = 128U,
kTVMExtReserveEnd = 64U,
kTVMExtEnd = 128U,
// The rest of the space is used for custom, user-supplied datatypes
kCustomBegin = 129U,
kTVMCustomBegin = 129U,
} TVMTypeCode;

/*!
* \brief The data type used in TVM Runtime.
*
* Examples
* - float: type_code = 2, bits = 32, lanes=1
* - float4(vectorized 4 float): type_code = 2, bits = 32, lanes=4
* - int8: type_code = 0, bits = 8, lanes=1
*
* \note Arguments TVM API function always takes bits=64 and lanes=1
*/
typedef DLDataType TVMType;

/*!
* \brief The Device information, abstract away common device types.
*/
typedef DLContext TVMContext;

/*!
* \brief The tensor array structure to TVM API.
*/
typedef DLTensor TVMArray;

/*! \brief the array handle */
typedef TVMArray* TVMArrayHandle;
typedef DLTensor* TVMArrayHandle;

/*!
* \brief Union type of values
Expand All @@ -152,13 +134,13 @@ typedef union {
double v_float64;
void* v_handle;
const char* v_str;
TVMType v_type;
DLDataType v_type;
TVMContext v_ctx;
} TVMValue;

/*!
* \brief Byte array type used to pass in byte array
* When kBytes is used as data type.
* When kTVMBytes is used as data type.
*/
typedef struct {
const char* data;
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/runtime/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class DataType {
kInt = kDLInt,
kUInt = kDLUInt,
kFloat = kDLFloat,
kHandle = TVMTypeCode::kHandle,
kHandle = TVMTypeCode::kTVMOpaqueHandle,
};
/*! \brief default constructor */
DataType() {}
Expand Down
6 changes: 3 additions & 3 deletions include/tvm/runtime/device_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class TVM_DLL DeviceAPI {
virtual void* AllocDataSpace(TVMContext ctx,
size_t nbytes,
size_t alignment,
TVMType type_hint) = 0;
DLDataType type_hint) = 0;
/*!
* \brief Free a data space on device.
* \param ctx The device context to perform operation.
Expand All @@ -115,7 +115,7 @@ class TVM_DLL DeviceAPI {
size_t num_bytes,
TVMContext ctx_from,
TVMContext ctx_to,
TVMType type_hint,
DLDataType type_hint,
TVMStreamHandle stream) = 0;
/*!
* \brief Create a new stream of execution.
Expand Down Expand Up @@ -177,7 +177,7 @@ class TVM_DLL DeviceAPI {
*/
virtual void* AllocWorkspace(TVMContext ctx,
size_t nbytes,
TVMType type_hint = {});
DLDataType type_hint = {});
/*!
* \brief Free temporal workspace in backend execution.
*
Expand Down
Loading