diff --git a/.gitignore b/.gitignore index 6be98c50466d..90a05eaa2cdf 100644 --- a/.gitignore +++ b/.gitignore @@ -127,6 +127,11 @@ CMakeFiles cmake_install.cmake lib +# Kate / KDevelop files +*.kate-swp +*.kdev4 + + # Visual Studio Code .vscode diff --git a/.gitmodules b/.gitmodules index cdb8a5536793..ea391ea584a6 100644 --- a/.gitmodules +++ b/.gitmodules @@ -26,3 +26,6 @@ [submodule "3rdparty/cub"] path = 3rdparty/cub url = https://github.com/dmlc/cub +[submodule "3rdparty/NNPACK"] + path = 3rdparty/NNPACK + url = https://github.com/Maratyszcza/NNPACK diff --git a/3rdparty/NNPACK b/3rdparty/NNPACK new file mode 160000 index 000000000000..83af25db1188 --- /dev/null +++ b/3rdparty/NNPACK @@ -0,0 +1 @@ +Subproject commit 83af25db11883e160e65005f065f260488643c26 diff --git a/CMakeLists.txt b/CMakeLists.txt index 16d365355ceb..6f97af51ad38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -34,6 +34,7 @@ mxnet_option(USE_VTUNE "Enable use of Intel Amplifier XE (VTune)" OFF mxnet_option(ENABLE_CUDA_RTC "Build with CUDA runtime compilation support" ON) mxnet_option(INSTALL_EXAMPLES "Install the example source files." OFF) mxnet_option(USE_SIGNAL_HANDLER "Print stack traces on segfaults." OFF) +mxnet_option(USE_NNPACK "Build with NNPack support." 
OFF) if(USE_CUDA AND NOT USE_OLDCMAKECUDA) message(STATUS "CMake version '${CMAKE_VERSION}' using generator '${CMAKE_GENERATOR}'") @@ -551,6 +552,52 @@ if(NOT EXISTS "${CMAKE_CURRENT_SOURCE_DIR}/nnvm/CMakeLists.txt") list(APPEND mxnet_LINKER_LIBS ${nnvm_LINKER_LIBS}) endif() +# ---[ NNPack +if(USE_NNPACK) + if (USE_MKLDNN) + message(FATAL_ERROR "Either MKLDNN or NNPack can be enabled but not both.") + endif() + # Add in NNPack and its dependencies (paths are relative to this project so subproject builds work) + set(NNPACK_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}/3rdparty/NNPACK") + if (EXISTS "${NNPACK_SOURCE_DIR}") + if (GTEST_ROOT AND EXISTS "${GTEST_ROOT}") + set(GOOGLETEST_SOURCE_DIR "${GTEST_ROOT}" CACHE STRING "Google Test source directory") + endif() + + # Disable NNPack's own tests and benchmarks + set(NNPACK_BUILD_TESTS OFF CACHE BOOL "") + set(NNPACK_BUILD_BENCHMARKS OFF CACHE BOOL "") + + # Compile NNPack and its dependencies as static libraries + set(NNPACK_LIBRARY_TYPE "static" CACHE STRING "") + set(PTHREADPOOL_LIBRARY_TYPE "static" CACHE STRING "") + set(CPUINFO_LIBRARY_TYPE "static" CACHE STRING "") + + # put NNPack dependencies in appropriate folders, see NNPack + # for other options + set(CONFU_DEPENDENCIES_SOURCE_DIR "${NNPACK_SOURCE_DIR}/deps" + CACHE PATH "Confu-style dependencies source directory") + set(CONFU_DEPENDENCIES_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}/3rdparty/NNPACK/deps" + CACHE PATH "Confu-style dependencies binary directory") + + add_subdirectory("${NNPACK_SOURCE_DIR}") + + # compile with -fPIC so the static libraries can be linked into the mxnet shared library + set_property(TARGET nnpack PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET pthreadpool PROPERTY POSITION_INDEPENDENT_CODE ON) + set_property(TARGET cpuinfo PROPERTY POSITION_INDEPENDENT_CODE ON) + + include_directories(${NNPACK_SOURCE_DIR}/include) + # PTHREADPOOL_SOURCE_DIR is set in cache by NNPack + include_directories(${PTHREADPOOL_SOURCE_DIR}/include) + add_definitions(-DMXNET_USE_NNPACK=1) + set(NNPack_LINKER_LIBS nnpack) + list(APPEND mxnet_LINKER_LIBS ${NNPack_LINKER_LIBS}) + else() + message(FATAL_ERROR "NNPack submodule not 
found at ${NNPACK_SOURCE_DIR}; run 'git submodule update --init --recursive'.") + endif() +endif() + if(NOT MSVC) # Only add c++11 flags and definitions after cuda compiling add_definitions(-DDMLC_USE_CXX11) diff --git a/src/operator/convolution_v1.cc b/src/operator/convolution_v1.cc index 86c0fbb33291..7d561f7d5ca8 100644 --- a/src/operator/convolution_v1.cc +++ b/src/operator/convolution_v1.cc @@ -25,9 +25,6 @@ */ #include "./convolution_v1-inl.h" -#if MXNET_USE_NNPACK == 1 -#include "./nnpack/nnpack_convolution-inl.h" -#endif // MXNET_USE_NNPACK namespace mxnet { namespace op { diff --git a/src/operator/nn/convolution.cc b/src/operator/nn/convolution.cc index 951063fb4b2f..69ed4713d2cf 100644 --- a/src/operator/nn/convolution.cc +++ b/src/operator/nn/convolution.cc @@ -28,9 +28,7 @@ #include "../elemwise_op_common.h" #include "./mkldnn/mkldnn_ops-inl.h" #include "./mkldnn/mkldnn_base-inl.h" -#if MXNET_USE_NNPACK == 1 -#include "./nnpack/nnpack_convolution-inl.h" -#endif // MXNET_USE_NNPACK +#include "./nnpack/nnpack_ops-inl.h" namespace mxnet { namespace op { diff --git a/src/operator/nn/fully_connected.cc b/src/operator/nn/fully_connected.cc index 4362408a23a1..b2b202a4ce08 100644 --- a/src/operator/nn/fully_connected.cc +++ b/src/operator/nn/fully_connected.cc @@ -25,9 +25,7 @@ #include "./fully_connected-inl.h" #include "./mkldnn/mkldnn_ops-inl.h" #include "./mkldnn/mkldnn_base-inl.h" -#if MXNET_USE_NNPACK == 1 -#include "./nnpack/nnpack_fully_connected-inl.h" -#endif // MXNET_USE_NNPACK +#include "./nnpack/nnpack_ops-inl.h" namespace mxnet { namespace op { diff --git a/src/operator/nn/nnpack/nnpack_ops-inl.h b/src/operator/nn/nnpack/nnpack_ops-inl.h new file mode 100644 index 000000000000..279349206aa1 --- /dev/null +++ b/src/operator/nn/nnpack/nnpack_ops-inl.h @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file nnpack_ops-inl.h + * \brief Declarations of NNPACK-backed operator entry points (compiled only when MXNET_USE_NNPACK == 1) + * \author David Braude +*/ + +#ifndef MXNET_OPERATOR_NN_NNPACK_NNPACK_OPS_INL_H_ +#define MXNET_OPERATOR_NN_NNPACK_NNPACK_OPS_INL_H_ + +#if MXNET_USE_NNPACK == 1 + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO: +// Convolutional layer +// Inference-optimized forward propagation (nnp_convolution_inference) +// Training-optimized forward propagation (nnp_convolution_output) +// Training-optimized backward input gradient update (nnp_convolution_input_gradient) +// Training-optimized backward kernel gradient update (nnp_convolution_kernel_gradient) +// Fully-connected layer +// Inference-optimized forward propagation (nnp_fully_connected_inference and nnp_fully_connected_inference_f16f32 version for FP16 weights) +// Training-optimized forward propagation (nnp_fully_connected_output) +// Max pooling layer +// Forward propagation, both for training and inference, (nnp_max_pooling_output) +// ReLU layer (with parametrized negative slope) +// Forward propagation, both for training and inference, optionally in-place, (nnp_relu_output) +// Backward input gradient update (nnp_relu_input_gradient) + +namespace mxnet { +namespace op { + +/* Softmax forward entry point; defined in nnpack_softmax.cc */ +void NNPACKSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const 
NDArray &out_data); + +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_NNPACK == 1 + +#endif // MXNET_OPERATOR_NN_NNPACK_NNPACK_OPS_INL_H_ diff --git a/src/operator/nnpack/nnpack_util.cc b/src/operator/nn/nnpack/nnpack_softmax.cc similarity index 59% rename from src/operator/nnpack/nnpack_util.cc rename to src/operator/nn/nnpack/nnpack_softmax.cc index 7d075e0554ba..a1246545e6c0 100644 --- a/src/operator/nnpack/nnpack_util.cc +++ b/src/operator/nn/nnpack/nnpack_softmax.cc @@ -18,20 +18,31 @@ */ /*! - * Copyright (c) 2016 by Contributors - * \file nnpack_util.cc + * \file nnpack_softmax.cc * \brief - * \author Wei Wu + * \author David Braude */ -#if MXNET_USE_NNPACK == 1 -#include "nnpack_util.h" +#include "../softmax-inl.h" +#include "./nnpack_ops-inl.h" + +#if MXNET_USE_NNPACK == 1 namespace mxnet { namespace op { -NNPACKInitialize nnpackinitialize; +void NNPACKSoftmaxForward(const nnvm::NodeAttrs& attrs, const OpContext &ctx, + const NDArray &in_data, const OpReqType &req, + const NDArray &out_data) { + const SoftmaxParam& param = nnvm::get(attrs.parsed); +// enum nnp_status nnp_softmax_output( +// size_t batch_size, +// size_t channels, +// const float input[], +// float output[], +// pthreadpool_t threadpool); +} -} // namespace op -} // namespace mxnet -#endif // MXNET_USE_NNPACK +} // namespace op +} // namespace mxnet +#endif // MXNET_USE_NNPACK == 1 diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index f719e0753e08..d13d43b13255 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -25,12 +25,8 @@ */ #include "../elemwise_op_common.h" #include "./pooling-inl.h" -#if MXNET_USE_NNPACK == 1 -#include "./nnpack/nnpack_pooling-inl.h" -#endif // MXNET_USE_NNPACK -#if MXNET_USE_MKLDNN == 1 +// #include "./nnpack/nnpack_pooling-inl.h" #include "./mkldnn/mkldnn_pooling-inl.h" -#endif // MXNET_USE_MKLDNN namespace mxnet { namespace op { diff --git a/src/operator/nnpack/nnpack_convolution-inl.h 
b/src/operator/nnpack/nnpack_convolution-inl.h deleted file mode 100644 index 0e2c73693d15..000000000000 --- a/src/operator/nnpack/nnpack_convolution-inl.h +++ /dev/null @@ -1,124 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * Copyright (c) 2016 by Contributors - * \file nnpack_convolution-inl.h - * \brief - * \author Carwin -*/ -#ifndef MXNET_OPERATOR_NNPACK_NNPACK_CONVOLUTION_INL_H_ -#define MXNET_OPERATOR_NNPACK_NNPACK_CONVOLUTION_INL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "../convolution-inl.h" -#include "nnpack.h" -#include "nnpack_util.h" - -namespace mxnet { -namespace op { - -template -class NNPACKConvolutionOp : public ConvolutionOp { - private: - ConvolutionParam param_; - - public: - explicit NNPACKConvolutionOp(ConvolutionParam p) - : ConvolutionOp(p) { - this->param_ = p; - } - - public: - virtual void Forward(const OpContext &ctx, const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - Tensor data = in_data[conv::kData].get(s); - const size_t batch_size = data.shape_[0]; - const size_t input_c = data.shape_[1]; - const size_t input_h = data.shape_[2]; - const size_t input_w = data.shape_[3]; - Shape<3> wmat_shape = - Shape3(param_.num_group, param_.num_filter / param_.num_group, - input_c / param_.num_group * param_.kernel[0] * - param_.kernel[1]); - Tensor wmat = - in_data[conv::kWeight].get_with_shape(wmat_shape, s); - Tensor out = out_data[conv::kOut].get(s); - nnp_size input_size = {input_w, input_h}; - nnp_padding input_padding = {param_.pad[0], param_.pad[1], param_.pad[0], - param_.pad[1]}; - nnp_size kernel_size = {param_.kernel[1], param_.kernel[0]}; - nnp_size output_subsampling = {param_.stride[1], param_.stride[0]}; - Tensor bias = in_data[conv::kBias].get(s); - - nnp_convolution_algorithm algorithm = nnp_convolution_algorithm_auto; - nnp_convolution_transform_strategy kts = nnp_convolution_transform_strategy_tuple_based; - nnp_status status = nnp_status_success; - if (batch_size == 1) { - status = nnp_convolution_inference( - algorithm, // enum 
nnp_convolution_algorithm, - kts, // enum nnp_convolution_transform_strategy, - input_c, // size_t input_channels, - param_.num_filter, // size_t output_channels, - input_size, // struct nnp_size input_size, - input_padding, // struct nnp_padding input_padding, - kernel_size, // struct nnp_size kernel_size, - output_subsampling, // struct nnp_size output_subsampling, - data.dptr_, // const float input[], - wmat.dptr_, // const float kernel[], - bias.dptr_, // const float bias[], - out.dptr_, // float output[], - nnpackinitialize.threadpool, // pthreadpool_t threadpool, - nullptr); - } else { - status = nnp_convolution_output( - algorithm, // enum nnp_convolution_algorithm algorithm, - batch_size, // size_t batch size of input tensor - input_c, // size_t input_channels, - param_.num_filter, // size_t output_channels, - input_size, // struct nnp_size input_size, - input_padding, // struct nnp_padding input_padding, - kernel_size, // struct nnp_size kernel_size, - data.dptr_, // const float input[], - wmat.dptr_, // const float kernel[], - bias.dptr_, // const float bias[], - out.dptr_, // float output[], - nnpackinitialize.threadpool, // pthreadpool_t threadpool, - nullptr); - } - if (nnp_status_success != status) { - LOG(FATAL) << "nnpack convolution feedforward failed status=" << status; - } - } -}; // class NNPACKConvolutionOp -} // namespace op -} // namespace mxnet -#endif // MXNET_OPERATOR_NNPACK_NNPACK_CONVOLUTION_INL_H_ diff --git a/src/operator/nnpack/nnpack_fully_connected-inl.h b/src/operator/nnpack/nnpack_fully_connected-inl.h deleted file mode 100644 index d9412d20d0c1..000000000000 --- a/src/operator/nnpack/nnpack_fully_connected-inl.h +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. 
The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2016 by Contributors - * \file nnpack_fully_connected-inl.h - * \brief - * \author Wei Wu -*/ -#ifndef MXNET_OPERATOR_NNPACK_NNPACK_FULLY_CONNECTED_INL_H_ -#define MXNET_OPERATOR_NNPACK_NNPACK_FULLY_CONNECTED_INL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "../fully_connected-inl.h" -#include "nnpack.h" -#include "nnpack_util.h" - -namespace mxnet { -namespace op { - -template -class NNPACKFullyConnectedOp : public FullyConnectedOp { - private: - FullyConnectedParam param_; - - public: - explicit NNPACKFullyConnectedOp(FullyConnectedParam p) - : FullyConnectedOp(p) { - this->param_ = p; - } - - public: - virtual void Forward(const OpContext &ctx, const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - if (req[fullc::kOut] == kNullOp) return; - CHECK_EQ(req[fullc::kOut], kWriteTo); - size_t expected = param_.no_bias ? 
2 : 3; - CHECK_EQ(in_data.size(), expected); - CHECK_EQ(out_data.size(), 1); - const TShape& ishape = in_data[fullc::kData].shape_; - const TShape& oshape = out_data[fullc::kOut].shape_; - Stream *s = ctx.get_stream(); - Tensor data = in_data[fullc::kData].get_with_shape( - Shape2(ishape[0], ishape.ProdShape(1, ishape.ndim())), s); - Tensor wmat = in_data[fullc::kWeight].get(s); - Tensor out = out_data[fullc::kOut].get_with_shape( - Shape2(oshape[0], oshape.ProdShape(1, oshape.ndim())), s); - const size_t batch_size = data.shape_[0]; - const size_t input_c = data.shape_[1]; - nnp_status status = nnp_status_success; - if (batch_size == 1) { - status = nnp_fully_connected_inference( - input_c, // size_t input_channels, - param_.num_hidden, // size_t output_channels, - data.dptr_, // const float input[], - wmat.dptr_, // const float kernel[], - out.dptr_, // float output[], - nnpackinitialize.threadpool); // pthreadpool_t threadpool, - } else { - status = nnp_fully_connected_output( - batch_size, // size_t batch size of input tensor - input_c, // size_t input_channels, - param_.num_hidden, // size_t output_channels, - data.dptr_, // const float input[], - wmat.dptr_, // const float kernel[], - out.dptr_, // float output[], - nnpackinitialize.threadpool, // pthreadpool_t threadpool, - nullptr); - } - if (nnp_status_success != status) { - LOG(FATAL) << "nnpack fully conneted feedforward failed status=" << status; - } - if (!param_.no_bias) { - Tensor bias = in_data[fullc::kBias].get(s); - out += repmat(bias, data.size(0)); - } - } -}; // class NNPACKFullyConnectedOp -} // namespace op -} // namespace mxnet -#endif // MXNET_OPERATOR_NNPACK_NNPACK_FULLY_CONNECTED_INL_H_ diff --git a/src/operator/nnpack/nnpack_pooling-inl.h b/src/operator/nnpack/nnpack_pooling-inl.h deleted file mode 100644 index 25b478322753..000000000000 --- a/src/operator/nnpack/nnpack_pooling-inl.h +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or 
more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * Copyright (c) 2016 by Contributors - * \file nnpack_pooling-inl.h - * \brief - * \author Wei Wu -*/ -#ifndef MXNET_OPERATOR_NNPACK_NNPACK_POOLING_INL_H_ -#define MXNET_OPERATOR_NNPACK_NNPACK_POOLING_INL_H_ - -#include -#include -#include -#include -#include -#include -#include -#include -#include "../pooling-inl.h" -#include "nnpack.h" -#include "nnpack_util.h" - -namespace mxnet { -namespace op { - -template -class NNPACKPoolingOp : public PoolingOp { - private: - PoolingParam param_; - - public: - explicit NNPACKPoolingOp(PoolingParam p) - : PoolingOp(p) { - this->param_ = p; - } - - public: - virtual void Forward(const OpContext &ctx, const std::vector &in_data, - const std::vector &req, - const std::vector &out_data, - const std::vector &aux_args) { - using namespace mshadow; - using namespace mshadow::expr; - Stream *s = ctx.get_stream(); - Tensor data = in_data[pool_enum::kData].get(s); - const size_t batch_size = data.shape_[0]; - const size_t input_c = data.shape_[1]; - const size_t input_h = data.shape_[2]; - const size_t input_w = data.shape_[3]; - Tensor out = out_data[pool_enum::kOut].get(s); - nnp_size input_size = {input_w, input_h}; - nnp_padding input_padding = {param_.pad[0], param_.pad[1], param_.pad[0], 
- param_.pad[1]}; - nnp_size kernel_size = {param_.kernel[1], param_.kernel[0]}; - nnp_size output_subsampling = {param_.stride[1], param_.stride[0]}; - nnp_status status = nnp_max_pooling_output( - batch_size, // size_t batch size of input tensor - input_c, // size_t input_channels, - input_size, // struct nnp_size input_size, - input_padding, // struct nnp_padding input_padding, - kernel_size, // struct nnp_size kernel_size, - output_subsampling, // struct nnp_size output_subsampling, - data.dptr_, // const float input[], - out.dptr_, // float output[], - nnpackinitialize.threadpool); // pthreadpool_t threadpool, - if (nnp_status_success != status) { - LOG(FATAL) << "nnpack max pooling feedforward failed status=" << status; - } - } -}; // class NNPACKPoolingOp -} // namespace op -} // namespace mxnet -#endif // MXNET_OPERATOR_NNPACK_NNPACK_POOLING_INL_H_ diff --git a/src/operator/nnpack/nnpack_util.h b/src/operator/nnpack/nnpack_util.h deleted file mode 100644 index 2edfb79ad46e..000000000000 --- a/src/operator/nnpack/nnpack_util.h +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! 
- * Copyright (c) 2016 by Contributors - * \file nnpack_util.h - * \brief - * \author Carwin -*/ -#ifndef MXNET_OPERATOR_NNPACK_NNPACK_UTIL_H_ -#define MXNET_OPERATOR_NNPACK_NNPACK_UTIL_H_ - -#include -#include -#include - -namespace mxnet { -namespace op { - -class NNPACKInitialize { - public: - pthreadpool_t threadpool; - - public: - NNPACKInitialize() { - nnp_status status = nnp_initialize(); - if (nnp_status_success != status) { - LOG(FATAL) << "nnp_initialize failed status=" << status; - } - int num_threads = dmlc::GetEnv("MXNET_CPU_NNPACK_NTHREADS", 4); - this->threadpool = pthreadpool_create(num_threads); - } - virtual ~NNPACKInitialize() { - nnp_status status = nnp_deinitialize(); - if (nnp_status_success != status) { - LOG(FATAL) << "nnp_deinitialize failed status=" << status; - } - pthreadpool_destroy(threadpool); - } -}; - -// nnpackinitialize will be used in all other nnpack op -extern NNPACKInitialize nnpackinitialize; - -} // namespace op -} // namespace mxnet - -#endif // MXNET_OPERATOR_NNPACK_NNPACK_UTIL_H_