diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h index 8b1a229250df..bb5a37fc8794 100644 --- a/src/operator/activation-inl.h +++ b/src/operator/activation-inl.h @@ -22,6 +22,7 @@ * \brief Activation operator * \author Bing Xu */ + #ifndef MXNET_OPERATOR_ACTIVATION_INL_H_ #define MXNET_OPERATOR_ACTIVATION_INL_H_ @@ -34,6 +35,7 @@ #include #include #include "./operator_common.h" +#include "./mxnet_op.h" namespace mxnet { namespace op { @@ -75,9 +77,16 @@ class ActivationOp : public Operator { CHECK_EQ(in_data.size(), 1U); CHECK_EQ(out_data.size(), 1U); Stream *s = ctx.get_stream(); - Tensor data = in_data[activation::kData].FlatTo2D(s); - Tensor out = out_data[activation::kOut].FlatTo2D(s); - Assign(out, req[activation::kOut], F(data)); + const TBlob& input = in_data[activation::kData]; + const size_t sz = input.shape_.Size(); + if (sz) { + MXNET_ASSIGN_REQ_SWITCH(req[activation::kOut], Req, { + mxnet_op::Kernel, xpu>::Launch( + s, sz, + out_data[activation::kOut].dptr(), + input.dptr()); + }); + } } virtual void Backward(const OpContext &ctx, @@ -93,14 +102,24 @@ class ActivationOp : public Operator { CHECK(in_data.size() == 1 && in_grad.size() == 1); CHECK_EQ(req.size(), 1U); Stream *s = ctx.get_stream(); - Tensor m_out_grad = out_grad[activation::kOut].FlatTo2D(s); - Tensor m_out_data = out_data[activation::kOut].FlatTo2D(s); - Tensor m_in_grad = in_grad[activation::kData].FlatTo2D(s); - Assign(m_in_grad, req[activation::kData], F(m_out_data) * m_out_grad); + const TBlob& m_out_grad = out_grad[activation::kOut]; + const TBlob& m_out_data = out_data[activation::kOut]; + const TBlob& m_in_grad = in_grad[activation::kData]; + const size_t sz = m_out_data.shape_.Size(); + if (sz) { + MXNET_ASSIGN_REQ_SWITCH(req[activation::kData], Req, { + mxnet_op::Kernel, Req>, xpu>::Launch( + s, sz, + m_in_grad.dptr(), + m_out_grad.dptr(), + m_out_data.dptr()); + }); + } } }; // class ActivationOp -// Decalre Factory function, used for dispatch specialization +// Declare Factory function, used for dispatch specialization template Operator* CreateOp(ActivationParam type, int dtype, const TShape& dshape); diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h index b6feb238d4b6..e9cf4f317dac 100644 --- a/src/operator/mxnet_op.h +++ b/src/operator/mxnet_op.h @@ -216,6 +216,20 @@ MSHADOW_CINLINE void copy(const TBlob& to, const TBlob& from) { }) } +/*! \brief Binary op backward gradient OP wrapper */ +template +struct backward_grad { + /* \brief Backward calc with grad + * \param a - output grad + * \param args... - data to grad calculation op (what this is -- input, output, etc. -- varies) + * \return input grad + */ + template + MSHADOW_XINLINE static DType Map(DType a, Args... args) { + return DType(a * GRAD_OP::Map(args...)); + } +}; + /*! \brief Select assignment operation based upon the req value * Also useful for mapping mshadow Compute (F) to Kernel::Launch */ diff --git a/tests/cpp/include/test_op.h b/tests/cpp/include/test_op.h index f30fbe8e6981..4b46b80b597d 100644 --- a/tests/cpp/include/test_op.h +++ b/tests/cpp/include/test_op.h @@ -100,7 +100,8 @@ class BasicOperatorData { #endif , initializeForward_(0) // unit testing may call inits in any order based , initializeBackward_(0) // upon its use-case (ie may not want to run forward pass first) - , initializeCallback_(0) { + , initializeCallback_(0) + , generator_(new std::mt19937()) { opContext_.is_train = true; opContext_.run_ctx.stream = nullptr; @@ -123,10 +124,14 @@ class BasicOperatorData { shape_input_vec_.resize(opProp.ListArguments().size()); op_.reset(opProp.CreateOperatorEx(getContext(), &shape_input_vec_, in_type)); if (op_) { + const size_t output_count = opProp.ListOutputs().size(); + const size_t aux_count = opProp.ListAuxiliaryStates().size(); // Figure out what sort of blobs we need to allocate std::vector out_shape, aux_shape; + out_shape.resize(output_count); + aux_shape.resize(aux_count); opProp.InferShape(&shape_input_vec_, &out_shape, &aux_shape); - std::vector out_type, aux_type; + std::vector out_type(output_count, -1), aux_type(aux_count, -1); opProp.InferType(in_type, &out_type, &aux_type); // Allocate top blobs (input) @@ -174,9 +179,9 @@ class BasicOperatorData { initForward(opProp, in_type); if (!initializeBackward_++) { for (size_t x = 0, n = static_cast(opProp.NumVisibleOutputs()); x < n; ++x) { - CHECK_LT(x, c_.blob_input_vec_.size()); - allocateBlob(&c_.blob_out_grad_, c_.blob_input_vec_[x].shape_, - false, c_.blob_input_vec_[x].type_flag_); + CHECK_LT(x, c_.blob_output_vec_.size()); + allocateBlob(&c_.blob_out_grad_, c_.blob_output_vec_[x].shape_, + false, c_.blob_output_vec_[x].type_flag_); } for (size_t x = 0, n = c_.blob_input_vec_.size(); x < n; ++x) { @@ -197,6 +202,7 @@ class BasicOperatorData { /*! \brief Run operator forward */ void forward(const size_t count = 1) { + const std::vector req(c_.blob_output_vec_.size(), kWriteTo); // Possibly move data to/from CPU and GPU (outside of timing scope) MXNET_CUDA_ONLY(std::unique_ptr gpuData(isGPU_ ? new GPUOpData(c_, &opContext_) : nullptr)); @@ -206,7 +212,7 @@ class BasicOperatorData { for (size_t x = 0; x < count; ++x) { op()->Forward(opContext_, c_.blob_input_vec_, - {kWriteTo, kWriteTo, kWriteTo}, + req, c_.blob_output_vec_, c_.blob_aux_states_); } @@ -214,7 +220,7 @@ class BasicOperatorData { for (size_t x = 0; x < count; ++x) { MXNET_CUDA_ONLY(op()->Forward(opContext_, gpuData->blob_input_vec_, - {kWriteTo, kWriteTo, kWriteTo}, + req, gpuData->blob_output_vec_, gpuData->blob_aux_states_)); } @@ -223,6 +229,7 @@ class BasicOperatorData { /*! \brief Run operator backwards */ void backward(const size_t count = 1) { + const std::vector req(c_.blob_output_vec_.size(), kWriteTo); // Possibly move data to/from CPU and GPU (outside of timing scope) MXNET_CUDA_ONLY(std::unique_ptr gpuData(isGPU_ ? new GPUOpData(c_, &opContext_) : nullptr)); @@ -234,7 +241,7 @@ class BasicOperatorData { c_.blob_out_grad_, c_.blob_input_vec_, c_.blob_output_vec_, - {kWriteTo, kWriteTo, kWriteTo}, + req, c_.blob_in_grad_, c_.blob_aux_states_); } @@ -244,7 +251,7 @@ class BasicOperatorData { gpuData->blob_out_grad_, gpuData->blob_input_vec_, gpuData->blob_output_vec_, - {kWriteTo, kWriteTo, kWriteTo}, + req, gpuData->blob_in_grad_, gpuData->blob_aux_states_)); } @@ -386,6 +393,21 @@ class BasicOperatorData { copy(blob, sourceData, 0, sourceDataSize); } + void FillRandom() { + std::uniform_real_distribution distribution(-1.0, 1.0); + for (size_t j = 0, jn = this->c_.all_blob_vects_.size(); j < jn; ++j) { + std::vector *data_vect = this->c_.all_blob_vects_[j]; + if (data_vect) { + for (size_t i = 0, n = data_vect->size(); i < n; ++i) { + TBlob &blob = (*data_vect)[i]; + test::patternFill(&blob, [this, &distribution]() -> DType { + return distribution(generator()); + }); + } + } + } + } + /*! \brief Input and output blobs */ OpContext opContext_; @@ -520,6 +542,9 @@ class BasicOperatorData { return allocateBlob(&standalone_blobs_, dest, shape, isGPU, dtype); } + /*! \brief mt19937 generator for random number generator */ + std::mt19937& generator() { return *generator_; } + /*! \brief Performance timing categories */ enum TimingId { Forward, @@ -539,6 +564,9 @@ class BasicOperatorData { /*! \brief scoped lifecycle management of allocated blobs */ std::list> standalone_blobs_; + /*! \brief Per-test generator */ + std::unique_ptr generator_; + public: /*! Timing instrumentation */ test::perf::TimingInstrument timing_; @@ -675,7 +703,7 @@ class Validator { } const TBlob& b1 = bv1[idx]; const TBlob& b2 = bv2[idx]; - if (print && test::debugOutput) { + if (print && test::debug_output) { test::print(RunContext(), &(std::cout << "Blob 1:"), b1, true, true); test::print(RunContext(), &(std::cout << "Blob 2:"), b2, true, true); } diff --git a/tests/cpp/include/test_op_runner.h b/tests/cpp/include/test_op_runner.h new file mode 100644 index 000000000000..6d0b766eb378 --- /dev/null +++ b/tests/cpp/include/test_op_runner.h @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file test_op_runner.h + * \brief Run a generic operator + * \author Chris Olivier +*/ +#ifndef TEST_OP_RUNNER_H_ +#define TEST_OP_RUNNER_H_ + +#include +#include +#include +#include "./test_op.h" + +namespace mxnet { +namespace test { + +/*! + * \brief Generic operator random test data + * \tparam DType Main data type + * \tparam AccReal Secondary data type (if any) + */ +template +class GenericOperatorData : public test::op::BasicOperatorData { + public: + typedef DType DataType; + typedef AccReal AccRealType; + + /*! + * \brief Constructor + * \param isGPU Is this to be used on GPU? + * \param inputShape Input shape to the operator + */ + GenericOperatorData(const bool isGPU, const TShape& inputShape) + : test::op::BasicOperatorData(isGPU, inputShape) { + } + + /*! + * \brief Reset forward pass by filling everything with random values + */ + void resetForward() override { + test::op::BasicOperatorData::FillRandom(); + } + + /*! + * \brief Reset backward pass by filling everything with random values + */ + void resetBackward() override { + test::op::BasicOperatorData::FillRandom(); + } +}; + +/*! + * \brief Generic operator runner + * \tparam OperatorProp property class for a given operator (i.e. FullyConnectedProp, BatchNormProp) + * \tparam OperatorDataContainer Data container for forward and backward passes for some given + * data types + */ +template +class OperatorRunner { + public: + typedef typename OperatorDataContainer::DataType DType; + typedef typename OperatorDataContainer::AccRealType AccReal; + + /*! + * \brief Test operator forward pass + * \param isGPU Whether this test is for GPU + * \param inputShape Input data shape + * \param kwargs Operator parameters + * \param OutShapeFunction Output shape function override + * \param count Number of times to run in each direction + * \return OpInfo object for further opereator analysis + */ + test::op::OpInfo + RunGenericOperatorForward( + bool isGPU, + const TShape &inputShape, + const std::vector > &kwargs, + const size_t count = 1) { +#if MXNET_USE_CUDA + if (isGPU && !test::unitTestsWithCuda) { + LOG(INFO) << "GPU not found, running test as non-GPU"; + } +#else + isGPU = false; +#endif + test::op::OpInfo info = + test::op::createOpAndInfoF(isGPU, inputShape, kwargs); + info.data_->initForward(*info.prop_, &info.in_type_); + info.data_->forward(count); + return info; + } + + /*! + * \brief Test operator backward pass + * \param info OpInfo object from forward pass + * \param count + * \return OpInfo object for further opereator analysis + */ + test::op::OpInfo RunGenericOperatorBackward( + test::op::OpInfo *info, + const size_t count = 1) { + info->data_->initBackward(*info->prop_, &info->in_type_); + info->data_->backward(count); + return *info; + } + + /*! + * \brief Run operator forward and backward + * \param isGPU Whether this test is for GPU + * \param inputShape Input data shape + * \param kwargs Operator parameters + * \param OutShapeFunction Output shape function override + * \param count Number of times to run in each direction + * \return + */ + test::op::OpInfo RunBidirectional( + bool isGPU, + const TShape &inputShape, + const std::vector > &kwargs, + const size_t count = 1) { + test::op::OpInfo info = + RunGenericOperatorForward(isGPU, inputShape, kwargs, count); + return RunGenericOperatorBackward(&info, count); + } + + /*! + * \brief Timing test a generic operator + * \tparam PropType + * \tparam DType Data type + * \tparam AccReal Accumulative data type (if any) + * \param label Label for performance output + * \param isGPU Whether this test is for GPU + * \param stochastic Whether shape should be random (batch size, channels, hm, w) + * \param kwargs Operator parameters + * \param dim Data dimensions + * \param count Number of times to run in each direction + */ + void TimingTest(const std::string& label, + const bool isGPU, + const bool stochastic, + const test::op::kwargs_t& kwargs, + int dim = 0, + size_t count = 1, + TShape timing_shape = TShape()) { + std::cout << std::endl << std::flush; + +#ifdef NDEBUG + size_t COUNT = 50; +#else + size_t COUNT = 5; +#endif + if (mxnet::test::quick_test) { + COUNT = 2; + count = 1; + } + + test::perf::TimingInstrument timing; + + std::stringstream ss; + ss << "Timing: " << COUNT << " iterations of " << count << " calls"; + if (timing_shape.ndim()) { + ss << ", shape = " << timing_shape << std::endl << std::flush; + } + std::cout << ss.str(); + + for (size_t i = 0; i < COUNT; ++i) { + index_t batchSize = 1; + index_t channels = 1; + index_t depth = 1; + index_t height = 1; + index_t width = 1; + + if (!timing_shape.ndim()) { + do { + batchSize = stochastic ? test::rangedRand(1U, TES_BATCH_SIZE * 2U) : TIMING_BATCH_SIZE; + channels = stochastic ? test::rangedRand(1U, TEST_CHANNELS * 2U) : TIMING_CHANNELS; + depth = stochastic ? test::rangedRand(1U, TEST_DEPTH * 2U) : TIMING_DEPTH; + height = stochastic ? test::rangedRand(1U, TEST_DH * 2U) : TIMING_DH; + width = stochastic ? test::rangedRand(1U, TEST_DW * 2U) : TIMING_DW; + } while (stochastic && (height * width) == 1U); + } else { + dim = timing_shape.ndim() - 1; + } + + const size_t D = dim ? dim - 1U : test::rangedRand(0U, 2U); + + test::op::OpInfo info; + switch (D) { + case 0: + info = RunGenericOperatorForward(isGPU, + timing_shape.ndim() ? timing_shape + : TShape({batchSize, + channels, + width}), + kwargs, + count); + break; + case 1: + info = RunGenericOperatorForward(isGPU, + timing_shape.ndim()? timing_shape + : TShape({batchSize, + channels, + height, + width}), + kwargs, + count); + break; + case 2: + info = RunGenericOperatorForward(isGPU, + timing_shape.ndim() ? timing_shape + : TShape({batchSize, + channels, + depth, + height, + width}), + kwargs, + count); + break; + default: + CHECK(false) << "Unsupported dimension count: " << (D + 1); + } + if (info.data_.get()) { + RunGenericOperatorBackward(&info, count); + timing += info.data_->timing_; + } + } while (false); + + timing.print(&std::cout, label); + std::cout << std::endl << std::flush; + } + + protected: + static constexpr int TES_BATCH_SIZE = 5; + static constexpr int TEST_CHANNELS = 3; + static constexpr int TEST_DEPTH = 2; + static constexpr int TEST_DH = 2; + static constexpr int TEST_DW = 3; + + static constexpr int TIMING_BATCH_SIZE = 128; + static constexpr int TIMING_CHANNELS = 3; + static constexpr int TIMING_DEPTH = 2; + static constexpr int TIMING_DH = 64; + static constexpr int TIMING_DW = 64; +}; + +} // namespace test +} // namespace mxnet + +#endif // TEST_OP_RUNNER_H_ diff --git a/tests/cpp/include/test_util.h b/tests/cpp/include/test_util.h index a788bf389be8..492a0783d227 100644 --- a/tests/cpp/include/test_util.h +++ b/tests/cpp/include/test_util.h @@ -40,8 +40,9 @@ namespace mxnet { namespace test { extern bool unitTestsWithCuda; -extern bool debugOutput; +extern bool debug_output; extern bool quick_test; +extern bool performance_run; /*! \brief Pause VTune analysis */ struct VTunePause { diff --git a/tests/cpp/operator/activation_perf.cc b/tests/cpp/operator/activation_perf.cc new file mode 100644 index 000000000000..c0a42173c003 --- /dev/null +++ b/tests/cpp/operator/activation_perf.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file activation_perf.cc + * \brief Perf/profile run of ActivationOp + * \author Chris Olivier + */ + +#include +#include +#include +#include "../../src/operator/activation-inl.h" +#include "../include/test_op_runner.h" + +using namespace mxnet; + +typedef std::vector > kwargs_t; +const kwargs_t basic_activation_args = { }; + +/*! + * \brief Generic bidirectional sanity test + */ +TEST(ACTIVATION_PERF, ExecuteBidirectional) { + TShape shape({5, 5}); + kwargs_t kwargs = basic_activation_args; + kwargs.push_back({"act_type", "tanh"}); + test::OperatorRunner> runner; + runner.RunBidirectional(false, shape, kwargs, 1); +} + +/*! + * \brief ActivationOp timing test for CPU + */ +TEST(ACTIVATION_PERF, TimingCPU) { + kwargs_t kwargs = basic_activation_args; + // Which math function is arbitrary since it will have roughly constant timing among approaches + kwargs.push_back({"act_type", "tanh"}); + test::OperatorRunner> runner; + runner.RunBidirectional(false, {10, 10, 10, 10}, kwargs, 1); // prime code and cache + std::vector shapes; + if (test::performance_run) { + shapes = { + {1, 1, 28, 28}, + {1, 3, 28, 28}, + {50, 1, 18, 32}, + {50, 3, 18, 32}, + {20, 3, 128, 128} + }; + } else { + shapes = { + {1, 1, 28, 28}, + {50, 3, 18, 32}, + }; + } + for (const TShape &shape : shapes) { + runner.TimingTest("Activation Operator CPU", false, false, kwargs, 2, 10, shape); + } +} + +#if MXNET_USE_CUDA == 1 +/*! + * \brief ActivationOp timing test for GPU + */ +TEST(ACTIVATION_PERF, TimingGPU) { + kwargs_t kwargs = basic_activation_args; + // Which math function is arbitrary since it will have roughly constant timing among approaches + kwargs.push_back({"act_type", "tanh"}); + test::OperatorRunner> runner; + runner.RunBidirectional(true, {10, 10, 10, 10}, kwargs, 1); // prime code and cache + std::vector shapes = { + {1, 1, 28, 28}, + {1, 3, 28, 28}, + {50, 1, 18, 32}, + {50, 3, 18, 32}, + {20, 3, 128, 128} + }; + for (const TShape &shape : shapes) { + runner.TimingTest("Activation Operator GPU", true, false, kwargs, 2, 10, shape); + } +} +#endif // MXNET_USE_CUDA == 1 diff --git a/tests/cpp/operator/batchnorm_test.cc b/tests/cpp/operator/batchnorm_test.cc index f04593322858..0eca871c3e22 100644 --- a/tests/cpp/operator/batchnorm_test.cc +++ b/tests/cpp/operator/batchnorm_test.cc @@ -450,7 +450,7 @@ template static StreamType& dumpF(StreamType *os, const test::op::OpInfo& prop, const size_t x = 0) { - if (test::debugOutput) { + if (test::debug_output) { *os << std::endl; if (x) { *os << "=============================" << std::endl; @@ -476,7 +476,7 @@ template static StreamType& dumpB(StreamType *os, const test::op::OpInfo& prop, const size_t x = 0) { - if (test::debugOutput) { + if (test::debug_output) { *os << std::endl; if (x) { *os << "=============================" << std::endl; @@ -1019,7 +1019,7 @@ TEST(BATCH_NORM, Test2DBackward_Complex) { MSHADOW_REAL_TYPE_SWITCH_EX( mshadow::kFloat32, DType, AccReal, { - test::ScopeSet noDebugOutput(&test::debugOutput, false); + test::ScopeSet noDebugOutput(&test::debug_output, false); const TShape inputShape({9, 14, 16, 91}); test::op::OpInfoPair bi = testForwardAndBackward( @@ -1226,7 +1226,7 @@ class ChannelAxisTestData { std::vector> channel_data_; static void print(const std::string& label, const std::vector>& m) { - if (test::debugOutput) { + if (test::debug_output) { if (!label.empty()) { std::cout << label << ": "; } @@ -1248,7 +1248,7 @@ class ChannelAxisTestData { } static void print(const std::string& label, const TBlob& blob) { - if (test::debugOutput) { + if (test::debug_output) { if (!label.empty()) { std::cout << label << ": "; } @@ -1364,7 +1364,7 @@ TEST(BATCH_NORM, TestChannelAxisSaveAndLoad) { /*! \brief Insert the channel field `channelCount` into the shape at `channelAxis` position */ static TShape MakeShape(const std::vector& shape, - unsigned int channelAxis, + signed int channelAxis, const size_t channelCount) { if (channelAxis < 0) { channelAxis += shape.size() + 1; @@ -1533,7 +1533,7 @@ TEST(BATCH_NORM, TestChannelAxisSimple) { * backward result equivalence here implies correctness for other channel positions */ TEST(BATCH_NORM, TestChannelAxis) { - test::ScopeSet noDebugOutput(&test::debugOutput, false); + test::ScopeSet noDebugOutput(&test::debug_output, false); test::op::kwargs_t kwargs; const std::vector> shapes = diff --git a/tests/cpp/operator/fully_conn_perf.cc b/tests/cpp/operator/fully_conn_perf.cc new file mode 100644 index 000000000000..4cb4b4522a96 --- /dev/null +++ b/tests/cpp/operator/fully_conn_perf.cc @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file fully_conn_perf.cc + * \brief Sample for running C++ performance tests on a single operator. This method is also + * useful for profiling with vtune or gprof, avoiding the "noise" of python and executor + * \author Chris Olivier + */ + +#include +#include +#include "../../src/operator/fully_connected-inl.h" +#include "../include/test_op_runner.h" + +using namespace mxnet; + +typedef std::vector > kwargs_t; + +const kwargs_t basic_fullyconn_args = { {"num_hidden", "250"} }; + +/*! + * \brief Generic bidirectional sanity test + */ +TEST(FULLY_CONNECTED, ExecuteBidirectionalFullyConnected) { + TShape shape({5, 5}); + kwargs_t kwargs = basic_fullyconn_args; + test::OperatorRunner> runner; + runner.RunBidirectional(false, shape, kwargs, 1); +} + +/*! + * \brief Timing test for CPU + */ +TEST(FULLY_CONNECTED, FullyConnectedTimingCPU) { + kwargs_t kwargs = basic_fullyconn_args; + test::OperatorRunner> + runner; + runner.RunBidirectional(false, {10, 10, 10, 10}, kwargs, 1); // prime code and cache + std::vector shapes; + if (test::performance_run) { + shapes = { + {1, 1, 28, 28}, + {1, 3, 28, 28}, + {50, 1, 18, 32}, + {50, 3, 18, 32}, + {20, 3, 128, 128} + }; + } else { + shapes = { + {1, 1, 28, 28}, + {50, 3, 18, 32}, + }; + } + for (const TShape& shape : shapes) { + runner.TimingTest("Fully connected CPU", false, false, kwargs, 2, 10, shape); + } +} + +#if MXNET_USE_CUDA == 1 +/*! + * \brief Timing test for GPU + */ +TEST(FULLY_CONNECTED, FullyConnectedTimingGPU) { + kwargs_t kwargs = basic_fullyconn_args; + test::op::OpInfo info; + test::OperatorRunner> runner; + runner.RunBidirectional(true, {10, 10, 10, 10}, kwargs, 1); // prime code and cache + const std::vector shapes = { + {1, 1, 28, 28}, {1, 3, 28, 28}, + {50, 1, 18, 32}, {50, 3, 18, 32} + }; + for (const TShape& shape : shapes) { + runner.TimingTest("Fully connected GPU", true, false, kwargs, 2, 10, shape); + } +} +#endif // MXNET_USE_CUDA == 1 diff --git a/tests/cpp/test_main.cc b/tests/cpp/test_main.cc index 5434a704c090..eaf9e3c21910 100644 --- a/tests/cpp/test_main.cc +++ b/tests/cpp/test_main.cc @@ -38,11 +38,12 @@ static bool dumpCallback(const google_breakpad::MinidumpDescriptor& descriptor, namespace mxnet { namespace test { bool unitTestsWithCuda = false; #ifdef NDEBUG -bool debugOutput = false; +bool debug_output = false; #else -bool debugOutput = false; +bool debug_output = false; #endif bool quick_test = false; +bool performance_run = false; }} #if MXNET_USE_CUDA @@ -85,7 +86,9 @@ int main(int argc, char ** argv) { // override (ie force attempt CUDA) mxnet::test::unitTestsWithCuda = true; } else if (!strcmp(argv[x], "--debug")) { - mxnet::test::debugOutput = true; + mxnet::test::debug_output = true; + } else if (!strcmp(argv[x], "--perf")) { + mxnet::test::performance_run = true; } else if (!strcmp(argv[x], "--quick") || !strcmp(argv[x], "-q")) { mxnet::test::quick_test = true; }