Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MXNet FFI for Operator Imperative Invocation #17510

Merged
merged 28 commits into from
Feb 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
b9c6ea1
Init
meta-project-ci Jan 6, 2020
6aafabd
Add nop
meta-project-ci Jan 16, 2020
78b885d
Add utility function SetInOut and Invoke
meta-project-ci Jan 16, 2020
939863a
Init ctypes
meta-project-ci Jan 19, 2020
06e75f6
Dispatch for default/CSR array
meta-project-ci Jan 21, 2020
08ba4c7
Refactor, register the funcs where they are used, except for _api_int…
meta-project-ci Jan 22, 2020
ee75e2f
Seperate tvm ffi api and legacy api
meta-project-ci Feb 3, 2020
f3922ff
Replace legacy zeros with new
meta-project-ci Feb 3, 2020
5004c9b
Fix numpy.int64 in shape
meta-project-ci Feb 3, 2020
5dd4799
Fix sanity
meta-project-ci Feb 3, 2020
8dd68bf
Fix
meta-project-ci Feb 4, 2020
490ce05
Remove python2 support
meta-project-ci Feb 4, 2020
b748349
Cleanup
meta-project-ci Feb 4, 2020
e3acc91
Fix ci
meta-project-ci Feb 4, 2020
e424a1f
Fix lint
meta-project-ci Feb 4, 2020
9848108
Revert rand_shape_nd
meta-project-ci Feb 5, 2020
6f9906d
Fix clang-tidy
meta-project-ci Feb 5, 2020
3aaad87
Support NDArray in ctypes
meta-project-ci Feb 6, 2020
a65380e
Using runtime
meta-project-ci Feb 6, 2020
e84b494
Conversion ctor
meta-project-ci Feb 6, 2020
167a825
Tensordot
meta-project-ci Feb 6, 2020
976fbd3
Tensordot backward
meta-project-ci Feb 6, 2020
1817c2a
Fix nop regression
meta-project-ci Feb 12, 2020
92f27b3
Deprecate Array
meta-project-ci Feb 12, 2020
f4fa30d
Fix comments
meta-project-ci Feb 16, 2020
ae194db
Fix comments
meta-project-ci Feb 19, 2020
be7db15
Add acknowledgement to incubator-tvm
meta-project-ci Feb 19, 2020
d29bb7d
Refactor according to comments
meta-project-ci Feb 23, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions ci/jenkins/Jenkins_steps.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,23 @@ utils = load('ci/Jenkinsfile_utils.groovy')

// mxnet libraries
mx_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy3/*.so'
mx_lib_cython = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so'
hzfan marked this conversation as resolved.
Show resolved Hide resolved

// Python wheels
mx_pip = 'build/*.whl'

// mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default.
mx_cmake_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_cmake_lib_no_tvm_op = 'build/libmxnet.so, build/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy3/*.so'
mx_cmake_lib_cython = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so'
// mxnet cmake libraries, in cmake builds we do not produce a libnvvm static library by default.
mx_cmake_lib_debug = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests'
mx_cmake_mkldnn_lib = 'build/libmxnet.so, build/libmxnet.a, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/3rdparty/dmlc-core/libdmlc.a, build/tests/mxnet_unit_tests, build/3rdparty/openmp/runtime/src/libomp.so'
mx_mkldnn_lib = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a'
mx_tensorrt_lib = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, lib/libnvonnxparser_runtime.so.0, lib/libnvonnxparser.so.0, lib/libonnx_proto.so, lib/libonnx.so'
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so'
mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, build/tests/cpp/mxnet_unit_tests'
mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so'
mx_lib_cpp_examples = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so'
mx_lib_cpp_capi = 'lib/libmxnet.so, lib/libmxnet.a, lib/libtvm_runtime.so, lib/libtvmop.so, lib/tvmop.conf, libsample_lib.so, lib/libmkldnn.so.1, lib/libmklml_intel.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so, build/tests/cpp/mxnet_unit_tests'
mx_lib_cpp_examples_no_tvm_op = 'lib/libmxnet.so, lib/libmxnet.a, build/libcustomop_lib.so, build/libcustomop_gpu_lib.so, build/libsubgraph_lib.so, 3rdparty/dmlc-core/libdmlc.a, 3rdparty/tvm/nnvm/lib/libnnvm.a, 3rdparty/ps-lite/build/libps.a, deps/lib/libprotobuf-lite.a, deps/lib/libzmq.a, build/cpp-package/example/*, python/mxnet/_cy3/*.so, python/mxnet/_ffi/_cy3/*.so'
mx_lib_cpp_examples_cpu = 'build/libmxnet.so, build/3rdparty/tvm/libtvm_runtime.so, build/libtvmop.so, build/tvmop.conf, build/cpp-package/example/*'

// Python unittest for CPU
Expand Down
48 changes: 48 additions & 0 deletions include/mxnet/api_registry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file api_registry.h
* \brief This file contains utilities related to
* the MXNet's global function registry.
*/
// Acknowledgement: This file originates from incubator-tvm
#ifndef MXNET_API_REGISTRY_H_
#define MXNET_API_REGISTRY_H_

#include <string>
#include <utility>
#include "runtime/registry.h"

namespace mxnet {
/*!
* \brief Register an API function globally.
* It simply redirects to MXNET_REGISTER_GLOBAL
*
* \code
* MXNET_REGISTER_API(MyPrint)
* .set_body([](MXNetArgs args, MXNetRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define MXNET_REGISTER_API(OpName) MXNET_REGISTER_GLOBAL(OpName)

} // namespace mxnet
#endif // MXNET_API_REGISTRY_H_
58 changes: 58 additions & 0 deletions include/mxnet/expr_operator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file expr_operator.h
* \brief Common operators defined for Expr.
*
* \note Most of the operator defined here perform simple constant folding
* when the type is int32 or int64 for simplifying the index expressions.
*/
// Acknowledgement: This file originates from incubator-tvm
// Acknowledgement: Most operator APIs originate from Halide.
#ifndef MXNET_EXPR_OPERATOR_H_
#define MXNET_EXPR_OPERATOR_H_

#include <mxnet/ir/expr.h>

namespace mxnet {

template<typename ValueType>
inline PrimExpr MakeConstScalar(MXNetDataType t, ValueType value) {
if (t.is_int()) return IntImm(t, static_cast<int64_t>(value));
if (t.is_float()) return FloatImm(t, static_cast<double>(value));
// customized type and uint is not supported for MXNet for now
LOG(FATAL) << "cannot make const for type " << t;
return PrimExpr();
}


template<typename ValueType>
inline PrimExpr make_const(MXNetDataType t, ValueType value) {
if (t.lanes() == 1) {
return MakeConstScalar(t, value);
} else {
LOG(FATAL) << "MXNetDataType::lanes() != 1 is not supported ";
}
return PrimExpr();
}

} // namespace mxnet

#endif // MXNET_EXPR_OPERATOR_H_
225 changes: 225 additions & 0 deletions include/mxnet/ir/expr.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,225 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file expr.h
* \brief Base expr nodes in MXNet.
*/
// Acknowledgement: This file originates from incubator-tvm
#ifndef MXNET_IR_EXPR_H_
#define MXNET_IR_EXPR_H_

#include <mxnet/runtime/object.h>
#include <mxnet/node/node.h>
#include <mxnet/node/container.h>
#include <mxnet/runtime/data_type.h>
#include <string>

namespace mxnet {

/*!
* \brief Base type of all the expressions.
* \sa Expr
*/
class BaseExprNode : public Object {
public:
static constexpr const char* _type_key = "Expr";
MXNET_DECLARE_BASE_OBJECT_INFO(BaseExprNode, Object);
};

/*!
* \brief Managed reference to BaseExprNode.
* \sa BaseExprNode
*/
class BaseExpr : public ObjectRef {
public:
/*! \brief Cosntructor */
BaseExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit BaseExpr(runtime::ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
/*! \brief The container type. */
using ContainerType = BaseExprNode;
};

/*!
* \brief Base node of all primitive expressions.
*
* A primitive expression deals with low-level
* POD data types and handles without
* doing life-cycle management for objects.
*
* PrimExpr is used in the low-level code
* optimizations and integer analysis.
*
* \sa PrimExpr
*/
class PrimExprNode : public BaseExprNode {
public:
/*!
* \brief The runtime data type of the primitive expression.
*
* MXNetDataType(dtype) provides coarse grained type information
* during compile time and runtime. It is eagerly built in
* PrimExpr expression construction and can be used for
* quick type checking.
*
* dtype is sufficient to decide the Type of the PrimExpr
* when it corresponds to POD value types such as i32.
*
* When dtype is MXNetDataType::Handle(), the expression could corresponds to
* a more fine-grained Type, and we can get the type by running lazy type inference.
*/
MXNetDataType dtype;

static constexpr const char* _type_key = "PrimExpr";
MXNET_DECLARE_BASE_OBJECT_INFO(PrimExprNode, BaseExprNode);
};

/*!
* \brief Reference to PrimExprNode.
* \sa PrimExprNode
*/
class PrimExpr : public BaseExpr {
public:
/*! \brief Cosntructor */
PrimExpr() {}
/*!
* \brief Cosntructor from object ptr.
* \param ptr The object pointer.
*/
explicit PrimExpr(runtime::ObjectPtr<Object> ptr) : BaseExpr(ptr) {}
/*!
* \brief construct from integer.
* \param value The value to be constructed.
*/
MXNET_DLL PrimExpr(int32_t value); // NOLINT(*)
/*!
* \brief construct from float.
* \param value The value to be constructed.
*/
MXNET_DLL PrimExpr(float value); // NOLINT(*)
/*!
* \brief construct from string.
* \param str The value to be constructed.
*/
MXNET_DLL PrimExpr(std::string str); // NOLINT(*)

/*! \return the data type of this expression. */
MXNetDataType dtype() const {
return static_cast<const PrimExprNode*>(get())->dtype;
}
/*! \brief The container type. */
using ContainerType = PrimExprNode;
};

/*!
* \brief Constant integer literals in the program.
* \sa IntImm
*/
class IntImmNode : public PrimExprNode {
public:
/*! \brief the Internal value. */
int64_t value;

static constexpr const char* _type_key = "IntImm";
MXNET_DECLARE_FINAL_OBJECT_INFO(IntImmNode, PrimExprNode);
};

/*!
* \brief Managed reference class to IntImmNode.
*
* \sa IntImmNode
*/
class IntImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
IntImm() {}
/*!
* \brief constructor from node.
*/
explicit IntImm(runtime::ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
MXNET_DLL IntImm(MXNetDataType dtype, int64_t value);
/*!
* \brief Get pointer to the internal value.
* \return the content of the integer.
*/
const IntImmNode* operator->() const {
return static_cast<const IntImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = IntImmNode;
};

/*!
* \brief Constant floating point literals in the program.
* \sa FloatImm
*/
class FloatImmNode : public PrimExprNode {
public:
/*! \brief The constant value content. */
double value;

static constexpr const char* _type_key = "FloatImm";
MXNET_DECLARE_FINAL_OBJECT_INFO(FloatImmNode, PrimExprNode);
};

/*!
* \brief Managed reference class to FloatImmNode.
*
* \sa FloatImmNode
*/
class FloatImm : public PrimExpr {
public:
/*!
* \brief Constructor
*/
FloatImm() {}
/*!
* \brief constructor from node.
*/
explicit FloatImm(runtime::ObjectPtr<Object> node) : PrimExpr(node) {}
/*!
* \brief Constructor.
* \param dtype The data type of the value.
* \param value The internal value.
*/
MXNET_DLL FloatImm(MXNetDataType dtype, double value);
/*!
* \brief Get pointer to the container.
* \return The pointer.
*/
const FloatImmNode* operator->() const {
return static_cast<const FloatImmNode*>(get());
}
/*! \brief type indicate the container type */
using ContainerType = FloatImmNode;
};

} // namespace mxnet
#endif // MXNET_IR_EXPR_H_
Loading