Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Deprecate Array
Browse files Browse the repository at this point in the history
  • Loading branch information
meta-project-ci committed Feb 12, 2020
1 parent ef0b748 commit 80d862a
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 92 deletions.
15 changes: 4 additions & 11 deletions include/mxnet/tuple.h
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,10 @@ class Tuple {

inline explicit Tuple(const runtime::ObjectRef& src) {
using namespace runtime;
if (const ADTObj* obj = src.as<ADTObj>()) {
this->SetDim(obj->size);
for (uint32_t i = 0; i < obj->size; ++i) {
this->begin()[i] = Downcast<Integer, ObjectRef>(obj->operator[](i))->value;
}
} else {
Array<IntImm> arr = Downcast<Array<IntImm>, ObjectRef>(src);
this->SetDim(arr.size());
for (size_t i = 0; i < arr.size(); ++i) {
this->begin()[i] = arr[i]->value;
}
ADT adt = Downcast<ADT, ObjectRef>(src);
this->SetDim(adt.size());
for (int i = 0; i < ndim_; ++i) {
this->begin()[i] = Downcast<Integer, ObjectRef>(adt[i])->value;
}
}

Expand Down
45 changes: 6 additions & 39 deletions python/mxnet/_ffi/_cython/convert.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -20,87 +20,54 @@
from libc.stdint cimport *
from numbers import Integral

cdef extern from "mxnet/runtime/object.h" namespace "mxnet::runtime":
cdef extern from "mxnet/runtime/ffi_helper.h" namespace "mxnet::runtime":
cdef cppclass Object:
pass

cdef cppclass ObjectPtr[T]:
ObjectPtr()

cdef cppclass ObjectRef:
ObjectRef()
ObjectRef(ObjectPtr[Object])
Object* get()

cdef ObjectPtr[T] GetObjectPtr[T](T* ptr)

pass

cdef extern from "mxnet/runtime/container.h" namespace "mxnet::runtime":
cdef cppclass ObjectRef:
const Object* get() const

cdef cppclass ADT(ObjectRef):
ADT()


cdef extern from "mxnet/runtime/ffi_helper.h" namespace "mxnet::runtime":
cdef cppclass ADTBuilder:
ADTBuilder()
ADTBuilder(uint32_t tag, uint32_t size)
void EmplaceInit(size_t idx, ObjectRef)
ADT Get()

cdef ObjectRef CreateEllipsis()

cdef cppclass Slice(ObjectRef):
Slice()
Slice(int)
Slice(int, int, int)

cdef cppclass Integer(ObjectRef):
Integer()
Integer(int)

cdef int64_t SliceNoneValue()


cdef extern from "mxnet/runtime/memory.h" namespace "mxnet::runtime":
cdef ObjectPtr[T] make_object[T]()
Integer(int64_t)


cdef inline ADT convert_tuple(tuple src_tuple) except *:
cdef uint32_t size = len(src_tuple)
cdef ADTBuilder builder = ADTBuilder(0, size)

for i in range(size):
builder.EmplaceInit(i, convert_object(src_tuple[i]))

return builder.Get()


cdef inline ADT convert_list(list src) except *:
cdef uint32_t size = len(src)
cdef ADTBuilder builder = ADTBuilder(0, size)

for i in range(size):
builder.EmplaceInit(i, convert_object(src[i]))

return builder.Get()

# cdef inline Slice convert_slice(slice slice_obj) except *:
# cdef int64_t kNoneValue = SliceNoneValue()
# return Slice(<int>(slice_obj.start) if slice_obj.start is not None else kNoneValue,
# <int>(slice_obj.stop) if slice_obj.stop is not None else kNoneValue,
# <int>(slice_obj.step) if slice_obj.step is not None else kNoneValue)


cdef inline ObjectRef convert_object(object src_obj) except *:
if isinstance(src_obj, int):
return Integer(<int>src_obj)
return Integer(<int64_t>src_obj)
elif isinstance(src_obj, tuple):
return convert_tuple(src_obj)
elif isinstance(src_obj, list):
return convert_list(src_obj)
if isinstance(src_obj, Integral):
return Integer(<int>src_obj)
return Integer(<int64_t>src_obj)
else:
raise TypeError("Don't know how to convert type %s" % type(src_obj))
12 changes: 4 additions & 8 deletions python/mxnet/_ffi/node_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,11 @@ def convert_to_node(value):
node : Node
The corresponding node value.
"""
if isinstance(value, bool):
return const(value, 'uint1x1')
if isinstance(value, Number):
return const(value)
if isinstance(value, (list, tuple)):
if isinstance(value, Integral):
return _api_internal._Integer(value)
elif isinstance(value, (list, tuple)):
value = [convert_to_node(x) for x in value]
return _api_internal._Array(*value)
if value is None:
return None
return _api_internal._ADT(*value)
raise ValueError("don't know how to convert type %s to node" % type(value))


Expand Down
58 changes: 41 additions & 17 deletions src/api/api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,33 +21,59 @@
* Implementation of API functions related to Higher DSL build.
* \file api_lang.cc
*/

#include <mxnet/runtime/packed_func.h>
#include <mxnet/api_registry.h>
#include <mxnet/base.h>
#include <mxnet/ir/expr.h>
#include <mxnet/node/container.h>
#include <mxnet/expr_operator.h>
#include <mxnet/runtime/packed_func.h>
#include <mxnet/ir/expr.h>
#include <mxnet/runtime/container.h>
#include <mxnet/runtime/ffi_helper.h>
#include <nnvm/c_api.h>
#include <iostream>

namespace mxnet {

MXNET_REGISTER_GLOBAL("_const")
// MXNET_REGISTER_GLOBAL("_const")
// .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
// if (args[0].type_code() == kDLInt) {
// *ret = make_const(args[1].operator MXNetDataType(),
// args[0].operator int64_t());
// } else if (args[0].type_code() == kDLFloat) {
// *ret = make_const(args[1].operator MXNetDataType(),
// args[0].operator double());
// } else {
// LOG(FATAL) << "only accept int or float";
// }
// });

MXNET_REGISTER_GLOBAL("_Integer")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
if (args[0].type_code() == kDLInt) {
*ret = make_const(args[1].operator MXNetDataType(),
args[0].operator int64_t());
} else if (args[0].type_code() == kDLFloat) {
*ret = make_const(args[1].operator MXNetDataType(),
args[0].operator double());
*ret = Integer(args[0].operator int64_t());
} else {
LOG(FATAL) << "only accept int or float";
LOG(FATAL) << "only accept int";
}
});
});

MXNET_REGISTER_GLOBAL("_Array")
// MXNET_REGISTER_GLOBAL("_Array")
// .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
// std::vector<ObjectRef> data;
// for (int i = 0; i < args.size(); ++i) {
// if (args[i].type_code() != kNull) {
// data.push_back(args[i].operator ObjectRef());
// } else {
// data.emplace_back(nullptr);
// }
// }
// auto node = make_object<ArrayNode>();
// node->data = std::move(data);
// *ret = Array<ObjectRef>(node);
// });

MXNET_REGISTER_GLOBAL("_ADT")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
std::vector<ObjectRef> data;
for (int i = 0; i < args.size(); ++i) {
if (args[i].type_code() != kNull) {
Expand All @@ -56,10 +82,8 @@ MXNET_REGISTER_GLOBAL("_Array")
data.emplace_back(nullptr);
}
}
auto node = make_object<ArrayNode>();
node->data = std::move(data);
*ret = Array<ObjectRef>(node);
});
*ret = ADT(0, data.begin(), data.end());
});

MXNET_REGISTER_GLOBAL("_Test")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
Expand Down
23 changes: 6 additions & 17 deletions src/api/api_npi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,13 @@ inline static void _npi_tensordot(runtime::MXNetArgs args,
op::TensordotParam param;
nnvm::NodeAttrs attrs;
attrs.op = op;
const ObjectRef ref = args[2].operator ObjectRef();
if (const ADTObj* obj = ref.as<ADTObj>()) {
if (const IntegerObj* lop = (*obj)[0].as<IntegerObj>()) {
param.a_axes_summed = Tuple<int>(1, lop->value);
param.b_axes_summed = Tuple<int>(1, Downcast<Integer, ObjectRef>((*obj)[1])->value);
} else {
param.a_axes_summed = Tuple<int>((*obj)[0]);
param.b_axes_summed = Tuple<int>((*obj)[1]);
}
ADT adt = Downcast<ADT, ObjectRef>(args[2].operator ObjectRef());
if (const IntegerObj* lop = adt[0].as<IntegerObj>()) {
param.a_axes_summed = Tuple<int>(1, lop->value);
param.b_axes_summed = Tuple<int>(1, Downcast<Integer, ObjectRef>(adt[1])->value);
} else {
Array<ObjectRef> arr = Downcast<Array<ObjectRef>, ObjectRef>(ref);
if (const IntImmNode* lop = arr[0].as<IntImmNode>()) {
param.a_axes_summed = Tuple<int>(1, lop->value);
param.b_axes_summed = Tuple<int>(1, Downcast<IntImm, ObjectRef>(arr[1])->value);
} else {
param.a_axes_summed = Tuple<int>(arr[0]);
param.b_axes_summed = Tuple<int>(arr[1]);
}
param.a_axes_summed = Tuple<int>(adt[0]);
param.b_axes_summed = Tuple<int>(adt[1]);
}
attrs.parsed = std::move(param);
int num_outputs = 0;
Expand Down

0 comments on commit 80d862a

Please sign in to comment.