diff --git a/doc/ui/cmd_argument/argument_outline.md b/doc/ui/cmd_argument/argument_outline.md index 98dadc270dcac..d6cc2c6ed7cc1 100644 --- a/doc/ui/cmd_argument/argument_outline.md +++ b/doc/ui/cmd_argument/argument_outline.md @@ -183,7 +183,7 @@ It looks like there are a lot of arguments. However, most of them are for develo -GPUgpu_id +GPUgpu_id √√√√ @@ -207,6 +207,11 @@ It looks like there are a lot of arguments. However, most of them are for develo √√√√ + +cudnn_conv_workspace_limit_in_mb +√√√√ + + RNN beam_size diff --git a/doc/ui/cmd_argument/detail_introduction.md b/doc/ui/cmd_argument/detail_introduction.md index 0d0362d022a72..07608e5edf740 100644 --- a/doc/ui/cmd_argument/detail_introduction.md +++ b/doc/ui/cmd_argument/detail_introduction.md @@ -163,6 +163,10 @@ - Choose path to dynamic load NVIDIA CUDA library, for instance, /usr/local/cuda/lib64. [Default]: LD_LIBRARY_PATH - type: string (default: "", null) +* `--cudnn_conv_workspace_limit_in_mb` + - Specify cuDNN max workspace limit, in units MB, 4096MB=4GB by default. + - type: int32 (default: 4096MB=4GB) + ## NLP: RNN/LSTM/GRU * `--rnn_use_batch` - Whether to use batch method for calculation in simple RecurrentLayer. diff --git a/paddle/cuda/include/hl_device_functions.cuh b/paddle/cuda/include/hl_device_functions.cuh index 88d950d6c1713..159c26f443cb1 100755 --- a/paddle/cuda/include/hl_device_functions.cuh +++ b/paddle/cuda/include/hl_device_functions.cuh @@ -48,5 +48,24 @@ inline __device__ double paddleAtomicAdd(double* address, double val) { } } // namespace paddle +/** + * @brief sum reduction + * + * @param[in,out] smem input data, better to use __shared__ memory. + * @param[in] tid thread index. + * @param[in] threads the total thread number used to reduce, + * such as, blockDim.x. + * + * @return smem[0]: the sum of each elements in smem. + */ +__device__ __forceinline__ +void simpleReduce(real* smem, int tid, int threads) { + for (unsigned int s = threads / 2; s > 0; s >>= 1) { + if (tid < s) { + smem[tid] += smem[tid + s]; + } + __syncthreads(); + } +} #endif /* HL_DEVICE_FUNCTIONS_CUH_ */ diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h index 17419790471a7..71e8f8e3a60c9 100644 --- a/paddle/cuda/include/hl_matrix.h +++ b/paddle/cuda/include/hl_matrix.h @@ -229,4 +229,40 @@ extern void hl_cossim_derivative(real* grad, int input2_height, real scale); +/** + * @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel]. + * + * @param[in] A_d input matrix (M x N). + * @param[in] B_d input matrix (1 x channel). + * @param[in] channel width of B. + * @param[in] dimM height of A. + * @param[in] dimN width of A. + * @param[in] scale scalar used for addition. + * + */ +extern void hl_matrix_add_shared_bias(real* A_d, + real* B_d, + const int channel, + const int dimM, + const int dimN, + real scale); + +/** + * @brief Matrix addition: A_d[i][j] += scale * B_d[j/channel]. + * + * @param[in] B_d input matrix (1 x channel). + * @param[in] A_d input matrix (M x N). + * @param[in] channel width of B. + * @param[in] dimM height of A. + * @param[in] dimN width of A. + * @param[in] scale scalar used for addition. + * + */ +extern void hl_matrix_collect_shared_bias(real* B_d, + real* A_d, + const int channel, + const int dimM, + const int dimN, + real scale); + #endif /* HL_MATRIX_H_ */ diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h index f1f1020c84d46..e37b1275432ca 100644 --- a/paddle/cuda/include/stub/hl_matrix_stub.h +++ b/paddle/cuda/include/stub/hl_matrix_stub.h @@ -101,4 +101,17 @@ inline void hl_cossim_derivative(real* grad, int input2_height, real scale) {} +inline void hl_matrix_add_shared_bias(real* A_d, + real* B_d, + const int channel, + const int dimM, + const int dimN, + real scale) {} + +inline void hl_matrix_collect_shared_bias(real* B_d, + real* A_d, + const int channel, + const int dimM, + const int dimN, + real scale) {} #endif // HL_MATRIX_STUB_H_ diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc index b215c0f6e33a1..7810d0d10053d 100644 --- a/paddle/cuda/src/hl_cuda_cudnn.cc +++ b/paddle/cuda/src/hl_cuda_cudnn.cc @@ -20,6 +20,11 @@ limitations under the License. */ #include "hl_thread.ph" #include "hl_dso_loader.h" #include "paddle/utils/Logging.h" +#include "paddle/utils/CommandLineParser.h" + +P_DEFINE_int32(cudnn_conv_workspace_limit_in_mb, 4096, + "Specify cuDNN max workspace limit, in units MB, " + "4096MB=4GB by default."); namespace dynload { @@ -242,7 +247,7 @@ void hl_conv_workspace(hl_tensor_descriptor input, CHECK_NOTNULL(conv); // Specify workspace limit directly - size_t memoryLimitBytes = 8 * 1024 * 1024; + size_t memoryLimitBytes = (1LL << 20) * FLAGS_cudnn_conv_workspace_limit_in_mb; // cudnn convolution forward configuration cudnnTensorDescriptor_t fwd_src_desc = GET_TENSOR_DESCRIPTOR(input); diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu index 067e68c41e119..3df9f63f9e4b7 100644 --- a/paddle/cuda/src/hl_cuda_matrix.cu +++ b/paddle/cuda/src/hl_cuda_matrix.cu @@ -20,6 +20,7 @@ limitations under the License. */ #include "hl_sequence.h" #include "paddle/utils/Logging.h" #include "hl_device_functions.cuh" +#include "hl_gpu_matrix_kernel.cuh" DEFINE_MATRIX_UNARY_OP(Zero, a = 0); DEFINE_MATRIX_TERNARY_PARAMETER_OP(_add, TWO_PARAMETER, c = p1*a + p2*b); @@ -673,3 +674,89 @@ void hl_cossim_derivative(real* grad, input1_height, input2_height, scale); CHECK_SYNC("hl_cossim_derivate failed"); } + +__global__ void KeMatrixAddSharedBias(real* A, + real* B, + const int channel, + const int M, + const int N, + real scale) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int dim = N / channel; + if (index < M * N) { + int i = index % N; + i = i / dim; + A[index] += scale * B[i]; + } +} + +void hl_matrix_add_shared_bias(real* A_d, + real* B_d, + const int channel, + const int dimM, + const int dimN, + real scale) { + const int blocks = 512; + const int grids = DIVUP(dimM * dimN, blocks); + KeMatrixAddSharedBias<<>> + (A_d, B_d, channel, dimM, dimN, scale); + CHECK_SYNC("hl_matrix_add_shared_bias failed"); +} + + +template +__global__ void KeMatrixCollectSharedBias(real *B, + real *A, + const int channel, + const int M, + const int N, + const int dim, + const int limit, + real scale) { + if (dim < limit) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < channel) { + real sum = 0.0; + for (int i = 0; i < M; ++i) { + for (int j = 0; j < dim; ++j) { + sum += A[i * N + index * dim + j]; + } + } + B[index] += scale * sum; + } + } else { + const int tid = threadIdx.x; + const int bid = blockIdx.x; + __shared__ real smem[blockSize]; + real sum = 0.0; + for (int j = 0; j < ((dim * M + blockSize - 1) / blockSize); ++j) { + int n = j * blockSize + tid; + int m = n / dim; + int w = n % dim; + smem[tid] = (m < M && w < dim) ? A[m * N + bid * dim + w] : 0.0; + __syncthreads(); + simpleReduce(smem, tid, blockSize); + sum += smem[0]; + } + if (tid == 0) { + B[bid] += scale * sum; + } + } +} + +void hl_matrix_collect_shared_bias(real* B_d, + real* A_d, + const int channel, + const int dimM, + const int dimN, + real scale) { + const int dim = dimN / channel; + const int blocks = 256; + const int limit = 64; + int grids = (dimM * dim) < limit ? DIVUP(channel, blocks) : channel; + + KeMatrixCollectSharedBias + <<< grids, blocks, 0, STREAM_DEFAULT>>> + (B_d, A_d, channel, dimM, dimN, dim, limit, scale); + CHECK_SYNC("hl_matrix_collect_shared_bias failed"); +} diff --git a/paddle/cuda/src/hl_cuda_sparse.cuh b/paddle/cuda/src/hl_cuda_sparse.cuh index c3b98f4ebc38d..9cf2d5a843343 100644 --- a/paddle/cuda/src/hl_cuda_sparse.cuh +++ b/paddle/cuda/src/hl_cuda_sparse.cuh @@ -908,24 +908,6 @@ int findIndex(int* indice, int num, int index) { return (end - 1); } -/** - * @brief sum reduction - * - * @param[in,out] smem input data, better to use __shared__ memory. - * @param[in] tid local thread index. - * @param[in] blockDimX the size of blockDim.x. - * - * note: return smem[0]: the sum of each elements of smem. - */ -__device__ __forceinline__ -void reduce(real* smem, int tid, int blockDimX) { - for (unsigned int s = blockDimX / 2; s > 0; s >>= 1) { - if (tid < s) { - smem[tid] += smem[tid + s]; - } - __syncthreads(); - } -} /** * @brief sum columns of csr sparse matrix (csr_val), then add to a_val. diff --git a/paddle/gserver/layers/ConcatenateLayer.cpp b/paddle/gserver/layers/ConcatenateLayer.cpp index 52a7cb6f777c3..bb6709b8df330 100644 --- a/paddle/gserver/layers/ConcatenateLayer.cpp +++ b/paddle/gserver/layers/ConcatenateLayer.cpp @@ -97,7 +97,8 @@ void ConcatenateLayer::backward(const UpdateCallback& callback) { */ class ConcatenateLayer2 : public Layer { public: - explicit ConcatenateLayer2(const LayerConfig& config) : Layer(config) {} + explicit ConcatenateLayer2(const LayerConfig& config) : + Layer(config) {} ~ConcatenateLayer2() {} @@ -110,6 +111,8 @@ class ConcatenateLayer2 : public Layer { std::vector> projections_; std::vector projOutput_; std::vector> projCol_; + bool sharedBias_; + std::unique_ptr biases_; }; REGISTER_LAYER(concat2, ConcatenateLayer2); @@ -119,7 +122,6 @@ bool ConcatenateLayer2::init(const LayerMap& layerMap, /* Initialize the basic parent class */ if (!Layer::init(layerMap, parameterMap)) return false; - CHECK(!biasParameter_); CHECK_EQ(inputLayers_.size(), parameters_.size()); projections_.reserve(inputLayers_.size()); projCol_.reserve(inputLayers_.size()); @@ -137,6 +139,13 @@ bool ConcatenateLayer2::init(const LayerMap& layerMap, } CHECK_EQ(getSize(), endCol); + /* initialize biases_ */ + if (biasParameter_.get() != NULL) { + sharedBias_ = config_.shared_biases(); + size_t psize = config_.bias_size(); + biases_ = std::unique_ptr(new Weight(1, psize, biasParameter_)); + } + return true; } @@ -154,8 +163,17 @@ void ConcatenateLayer2::forward(PassType passType) { projOutput_[i].grad = output_.grad->subColMatrix(startCol, endCol); } - for (size_t i = 0; i != inputLayers_.size(); ++i) { - projections_[i]->forward(&getInput(i), &projOutput_[i], passType); + { + AsyncGpuBlock block; + for (size_t i = 0; i != inputLayers_.size(); ++i) { + projections_[i]->forward(&getInput(i), &projOutput_[i], passType); + } + } + + /* add the bias-vector */ + if (biases_) { + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + output_.value->addBias(*(biases_->getW()), 1, sharedBias_); } /* activation */ { @@ -170,6 +188,13 @@ void ConcatenateLayer2::backward(const UpdateCallback& callback) { backwardActivation(); } + AsyncGpuBlock block; + if (biases_ && biases_->getWGrad()) { + REGISTER_TIMER_INFO("Concat2BpBiasTimer", getName().c_str()); + biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBias_); + biases_->getParameterPtr()->incUpdate(callback); + } + for (size_t i = 0; i != inputLayers_.size(); ++i) { if (projections_[i]) { projections_[i]->backward(callback); diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp index 9ed9572139dc8..040510b7ad211 100644 --- a/paddle/gserver/layers/ConvBaseLayer.cpp +++ b/paddle/gserver/layers/ConvBaseLayer.cpp @@ -35,25 +35,12 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, filterSizeY_.push_back(conf.filter_size_y()); filterPixels_.push_back(filterSize_.back() * filterSizeY_.back()); channels_.push_back(conf.channels()); - imgSize_.push_back(conf.img_size()); - imgPixels_.push_back(imgSize_.back() * imgSize_.back()); + imgSizeH_.push_back(conf.img_size()); + imgSizeW_.push_back(conf.img_size()); groups_.push_back(conf.groups()); filterChannels_.push_back(conf.filter_channels()); - outputX_.push_back(conf.output_x()); - outputs_.push_back(outputX_.back() * outputX_.back()); - } - - /* initialize the weightList */ - CHECK(inputLayers_.size() == parameters_.size()); - for (size_t i = 0; i < inputLayers_.size(); i++) { - size_t height, width; - height = filterPixels_[i] * filterChannels_[i]; - width = numFilters_; - - // create a new weight - CHECK_EQ(parameters_[i]->getSize(), width * height); - Weight* w = new Weight(height, width, parameters_[i]); - weights_.emplace_back(w); + outputH_.push_back(conf.output_x()); + outputW_.push_back(conf.output_x()); } /* initialize the biases_ */ @@ -74,4 +61,34 @@ bool ConvBaseLayer::init(const LayerMap& layerMap, return true; } +size_t ConvBaseLayer::calOutputSize() { + auto clearAndReserve = [this](IntV* vec) { + vec->clear(); + vec->reserve(this->inputLayers_.size()); + }; + clearAndReserve(&imgSizeH_); + clearAndReserve(&imgSizeW_); + clearAndReserve(&outputH_); + clearAndReserve(&outputW_); + size_t layerSize = 0; + for (size_t i = 0; i < inputLayers_.size(); i++) { + imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); + imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); + if (imgSizeH_[i] == 0) + imgSizeH_[i] = config_.inputs(i).conv_conf().img_size(); + if (imgSizeW_[i] == 0) + imgSizeW_[i] = config_.inputs(i).conv_conf().img_size(); + outputH_.push_back( + outputSize(imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i])); + outputW_.push_back( + outputSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i])); + CHECK_EQ(outputH_[i], outputH_[0]); + CHECK_EQ(outputW_[i], outputW_[0]); + } + getOutput().setFrameHeight(outputH_[0]); + getOutput().setFrameWidth(outputW_[0]); + layerSize = outputH_[0] * outputW_[0] * size_t(numFilters_); + return layerSize; +} + } // namespace paddle diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h index eaeaebf43be25..316514acf1a0d 100644 --- a/paddle/gserver/layers/ConvBaseLayer.h +++ b/paddle/gserver/layers/ConvBaseLayer.h @@ -43,19 +43,18 @@ class ConvBaseLayer : public Layer { IntV filterSizeY_; /// The spatial dimensions of the convolution input. IntV channels_; - /// The spatial dimensions of input feature map. - IntV imgSize_; - /// The total pixel size of input feature map. - /// imgPixels_ = imgSizeX_ * imgSizeY_. - IntV imgPixels_; + /// The spatial dimensions of input feature map height. + IntV imgSizeH_; + /// The spatial dimensions of input feature map width. + IntV imgSizeW_; /// filterPixels_ = filterSizeX_ * filterSizeY_. IntV filterPixels_; /// filterChannels_ = channels_/groups_. IntV filterChannels_; - /// The spatial dimensions of output feature map. - IntV outputX_; - /// The spatial dimensions of output feature map. - IntV outputs_; + /// The spatial dimensions of output feature map height. + IntV outputH_; + /// The spatial dimensions of output feature map width. + IntV outputW_; /// Group size, refer to grouped convolution in /// Alex Krizhevsky's paper: when group=2, the first half of the /// filters are only connected to the first half of the input channels, @@ -80,6 +79,13 @@ class ConvBaseLayer : public Layer { virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); + /** + * imgSizeH_ and imgSizeW_ will be set according to the previous input layers + * in this function. Then it will calculate outputH_ and outputW_ and set them + * into output argument. + */ + virtual size_t calOutputSize(); + Weight& getWeight(int idx) { return *weights_[idx]; } /** diff --git a/paddle/gserver/layers/ConvProjection.cpp b/paddle/gserver/layers/ConvProjection.cpp new file mode 100644 index 0000000000000..d1ce53fe26351 --- /dev/null +++ b/paddle/gserver/layers/ConvProjection.cpp @@ -0,0 +1,210 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#include "paddle/utils/Stat.h" +#include "ConvProjection.h" + +namespace paddle { + +REGISTER_PROJECTION(conv, ConvProjection); + +ThreadLocalD> ConvProjection::convMem_; + +ConvProjection::ConvProjection(const ProjectionConfig& config, + ParameterPtr parameter, bool useGpu) + : Projection(config, parameter, useGpu) { + + CHECK(useGpu); // only support GPU + getConvParams(); + initCudnn(); + + size_t height = filterH_ * filterW_ * channels_ / groups_; + size_t width = numFilters_; + weight_.reset(new Weight(height, width, parameter)); + weightOffset_ = height * width / groups_; +} + +void ConvProjection::getConvParams() { + const ConvConfig &conf = config_.conv_conf(); + paddingH_ = conf.padding_y(); + paddingW_ = conf.padding(); + + strideH_ = conf.stride_y(); + strideW_ = conf.stride(); + + filterH_ = conf.filter_size_y(); + filterW_ = conf.filter_size(); + + configImgH_ = conf.img_size(); + configImgW_ = conf.img_size(); + + channels_ = conf.channels(); + numFilters_ = config_.num_filters(); + + groups_ = conf.groups(); + CHECK_EQ(channels_ % groups_, 0); + CHECK_EQ(numFilters_ % groups_, 0); +} + +void ConvProjection::initCudnn() { + hl_create_filter_descriptor(&filterDesc_, channels_, numFilters_, + filterH_, filterW_); + hl_create_tensor_descriptor(&inputDesc_); + hl_create_tensor_descriptor(&outputDesc_); + hl_create_convolution_descriptor(&convDesc_, inputDesc_, filterDesc_, + paddingH_, paddingW_, strideH_, strideW_); + + // initialize all to default algorithms + fwdAlgo_ = 0; + bwdFilterAlgo_ = 0; + bwdDataAlgo_ = 0; + fwdLimitBytes_ = 0; + bwdDataLimitBytes_ = 0; + bwdFilterLimitBytes_ = 0; + workSpaceInBytes_ = 0; + + batchNum_ = 0; + isSelectAlgo_ = false; +} + +void ConvProjection::reshapeTensorDesc(int batchSize) { + hl_tensor_reshape(inputDesc_, batchSize, channels_, imageH_, imageW_, + channels_ * imageH_ * imageW_, imageH_ * imageW_, + imageW_, 1); + hl_reset_convolution_descriptor(convDesc_, inputDesc_, filterDesc_, + paddingH_, paddingW_, strideH_, strideW_); + + // The stride between two consecutive images in ConvProjection may not be 1, + // for example, in the case of layer ConcatenateLayer2 with two + // ConvProjection, the stride is the output_size of layer ConcatenateLayer2. + // So the calculation of nStride is different from CudnnConvLayer. + // In fact, only "nStride = out_->value->getStride()" is ok. + size_t nStride = numFilters_ * outputH_ * outputW_; + if (out_->value->isContiguous()) { + CHECK_EQ(nStride, out_->value->getWidth()); + } else { + nStride = out_->value->getStride(); + } + + hl_tensor_reshape(outputDesc_, batchSize, numFilters_, outputH_, outputW_, + nStride, outputH_ * outputW_, outputW_, 1); +} + +void ConvProjection::reshape(int batchSize) { + size_t width = calOutputSize(); + CHECK_EQ(width, out_->value->getWidth()); + + isSelectAlgo_ = (batchSize == batchNum_); + batchNum_ = batchSize; + + if (!isSelectAlgo_) { + reshapeTensorDesc(batchSize); + hl_conv_workspace(inputDesc_, outputDesc_, filterDesc_, + convDesc_, &fwdAlgo_, &fwdLimitBytes_, + &bwdDataAlgo_, &bwdDataLimitBytes_, + &bwdFilterAlgo_, &bwdFilterLimitBytes_); + + size_t maxWorkSpace = 0; + maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_); + maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_); + workSpaceInBytes_ = maxWorkSpace; + + + VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_ + << " / " << bwdDataAlgo_ + << " / " << bwdFilterAlgo_; + } + + isSelectAlgo_ = true; +} + +void ConvProjection::forward() { + int batchSize = in_->value->getHeight(); + reshape(batchSize); + + void* workSpace = NULL; + if (workSpaceInBytes_ > 0) { + workSpace = getSpaceBytes(workSpaceInBytes_); + } + + for (int g = 0; g < groups_; ++g) { + REGISTER_TIMER_INFO("CudnnConvFwTimer", getName().c_str()); + + real *inputData = in_->value->getData() + g * inputOffset_; + real *wgtData = weight_->getW()->getData() + g * weightOffset_; + real *outData = out_->value->getData() + g * outputOffset_; + hl_convolution_forward(inputDesc_, inputData, outputDesc_, + outData, filterDesc_, wgtData, + convDesc_, workSpace, + fwdLimitBytes_, fwdAlgo_); + } +} + +void ConvProjection::backward(const UpdateCallback& callback) { + REGISTER_TIMER_INFO("CudnnConvBpTimer", getName().c_str()); + + void* workSpace = NULL; + if (workSpaceInBytes_ > 0) { + workSpace = getSpaceBytes(workSpaceInBytes_); + } + + for (int g = 0; g < groups_; ++g) { + real *outGrad = out_->grad->getData() + g * outputOffset_; + if (weight_->getWGrad()) { + real *inputData = in_->value->getData() + g * inputOffset_; + real *weightGrad = weight_->getWGrad()->getData() + g * weightOffset_; + hl_convolution_backward_filter( + inputDesc_, inputData, outputDesc_, outGrad, filterDesc_, + weightGrad, convDesc_, workSpace, bwdFilterLimitBytes_, + bwdFilterAlgo_); + } + + MatrixPtr preGrad = in_->grad; + if (NULL != preGrad) { + real *inputGrad = preGrad->getData() + g * inputOffset_; + real *wgtData = weight_->getW()->getData() + g* weightOffset_; + hl_convolution_backward_data( + inputDesc_, inputGrad, outputDesc_, outGrad, filterDesc_, + wgtData, convDesc_, workSpace, bwdDataLimitBytes_, + bwdDataAlgo_); + } + } + + weight_->getParameterPtr()->incUpdate(callback); +} + +void* ConvProjection::getSpaceBytes(size_t size) { + std::vector& convMem = *convMem_; + if (convMem.empty()) { + int numDevices = hl_get_device_count(); + convMem.resize(numDevices); + } + + int devId = hl_get_device(); + MemoryHandle** localMem = &(convMem[devId]); + if (NULL == *localMem || size > (*localMem)->getAllocSize()) { + *localMem = new GpuMemoryHandle(size); + } + return (*localMem)->getBuf(); +} + +ConvProjection::~ConvProjection() { + hl_destroy_tensor_descriptor(inputDesc_); + hl_destroy_tensor_descriptor(outputDesc_); + hl_destroy_filter_descriptor(filterDesc_); + hl_destroy_convolution_descriptor(convDesc_); +} + +} // namespace paddle diff --git a/paddle/gserver/layers/ConvProjection.h b/paddle/gserver/layers/ConvProjection.h new file mode 100644 index 0000000000000..41a100ac3c50f --- /dev/null +++ b/paddle/gserver/layers/ConvProjection.h @@ -0,0 +1,125 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + + +#pragma once + +#include "Projection.h" + +namespace paddle { + +/** + * @brief Convolution projection do the same calculation with CudnnConvLayer. + */ +class ConvProjection : public Projection { +public: + /** + * Constructor. + */ + ConvProjection(const ProjectionConfig& config, ParameterPtr parameter, + bool useGpu); + + ~ConvProjection(); + + virtual void forward(); + virtual void backward(const UpdateCallback& callback); + +protected: + void getConvParams(); + void initCudnn(); + + void reshapeTensorDesc(int batchSize); + void reshape(int batchSize); + + int outputSize(int imageSize, int filterSize, int padding, int stride) { + return (imageSize - filterSize + 2 * padding) / stride + 1; + } + + size_t calOutputSize() { + imageH_ = in_->getFrameHeight(); + imageW_ = in_->getFrameWidth(); + if (imageH_ == 0) imageH_ = configImgH_; + if (imageW_ == 0) imageW_ = configImgW_; + outputH_ = outputSize(imageH_, filterH_, paddingH_, strideH_); + outputW_ = outputSize(imageW_, filterW_, paddingW_, strideW_); + + const_cast(out_)->setFrameHeight(outputH_); + const_cast(out_)->setFrameWidth(outputW_); + + inputOffset_ = (channels_ / groups_) * imageH_ * imageW_; + outputOffset_ = (numFilters_ / groups_) * outputH_ * outputW_; + return outputH_ * outputW_ * numFilters_; + } + + static void* getSpaceBytes(size_t size); + + /// imageH_ and imageW_ is calculated from the input layer. + int imageH_, imageW_; + /// configImgH_ and configImgW_ is obtained from config. + int configImgH_, configImgW_; + int outputH_, outputW_; + int channels_, numFilters_; + int paddingH_, paddingW_; + int strideH_, strideW_; + int filterH_, filterW_; + /// One group offset of input data. + int inputOffset_; + /// One group offset of output data. + int outputOffset_; + /// One group offset of weight. + int weightOffset_; + int groups_; + + /// Cudnn tensor descriptor for input. + hl_tensor_descriptor inputDesc_; + /// Cudnn tensor descriptor for output. + hl_tensor_descriptor outputDesc_; + /// Cudnn tensor descriptor for filter. + hl_filter_descriptor filterDesc_; + /// Cudnn tensor descriptor for a convolution operation. + hl_convolution_descriptor convDesc_; + + /// Record the algorithm for forward convolution, which is obtained by cudnn + /// api to search the best suited algorithm. + int fwdAlgo_; + /// Record the algorithm for computing convolution gradient with respect to + /// filter coefficients. + int bwdFilterAlgo_; + /// Record the algorithm for computing convolution gradient with respect to + /// the output. + int bwdDataAlgo_; + /// Amount of GPU memory needed as workspace to be able to execute a + /// forward convolution with the specified algo. + size_t fwdLimitBytes_; + /// Amount of GPU memory needed as workspace to be able to execute a + /// backwardFilter with the specified algo. + size_t bwdDataLimitBytes_; + /// Amount of GPU memory needed as workspace to be able to execute a + /// backwardData with the specified algo. + size_t bwdFilterLimitBytes_; + /// Size of total work space. + size_t workSpaceInBytes_; + + /// Whether to call cuDNN api to choose conv algorithm. + bool isSelectAlgo_; + /// batchNum is used to record batch size. If the batch size is changed, + /// the selection algorithm will be called. + int batchNum_; + bool bias_; + + std::unique_ptr weight_; + static ThreadLocalD> convMem_; +}; + +} // namespace paddle diff --git a/paddle/gserver/layers/CudnnConvLayer.cpp b/paddle/gserver/layers/CudnnConvLayer.cpp index 0f932f960f6ba..e77216f17c3be 100644 --- a/paddle/gserver/layers/CudnnConvLayer.cpp +++ b/paddle/gserver/layers/CudnnConvLayer.cpp @@ -22,215 +22,64 @@ REGISTER_LAYER(cudnn_conv, CudnnConvLayer); bool CudnnConvLayer::init(const LayerMap &layerMap, const ParameterMap ¶meterMap) { - ConvBaseLayer::init(layerMap, parameterMap); + if (!ConvBaseLayer::init(layerMap, parameterMap)) return false; CHECK(useGpu_) << "CudnnConvLayer only support gpu"; - maxGroups_ = 0; - for (size_t i = 0; i < inputLayers_.size(); i++) { - CHECK_EQ(channels_[i] % groups_[i], 0); - CHECK_EQ(numFilters_ % groups_[i], 0); - - hl_filter_descriptor filter; - hl_create_filter_descriptor(&filter, channels_[i] / groups_[i], - numFilters_ / groups_[i], filterSizeY_[i], - filterSize_[i]); - filterDesc_.push_back(filter); - - hl_tensor_descriptor input; - hl_create_tensor_descriptor(&input); - inputDesc_.push_back(input); - - hl_tensor_descriptor output; - int outputX = - outputSize(imgSize_[i], filterSize_[i], padding_[i], stride_[i]); - CHECK_EQ(outputX, outputX_[i]); - hl_create_tensor_descriptor(&output); - outputDesc_.push_back(output); + CHECK_EQ(inputLayers_.size(), parameters_.size()); + projections_.reserve(inputLayers_.size()); + projConf_.reserve(inputLayers_.size()); - hl_convolution_descriptor conv; - hl_create_convolution_descriptor(&conv, input, filter, paddingY_[i], - padding_[i], strideY_[i], stride_[i]); - convDesc_.push_back(conv); - - weightOffset_.push_back((numFilters_ / groups_[i]) * - (channels_[i] / groups_[i]) * filterPixels_[i]); - inputOffset_.push_back((channels_[i] / groups_[i]) * imgSize_[i] * - imgSize_[i]); - outputOffset_.push_back((numFilters_ / groups_[i]) * outputX_[i] * - outputX_[i]); - - // initialize all to default algorithms - fwdAlgo_.push_back(0); - bwdFilterAlgo_.push_back(0); - bwdDataAlgo_.push_back(0); - fwdLimitBytes_.push_back(0); - bwdFilterLimitBytes_.push_back(0); - bwdDataLimitBytes_.push_back(0); - - // cudnn streams per group equal to 1 - if (groups_[i] > maxGroups_) { - maxGroups_ = groups_[i]; - } - } - - workSpaceInBytes_ = 0; - workSpaceData_ = NULL; - for (int i = 0; i < maxGroups_; ++i) { - workSpace_.push_back(NULL); + numFilters_ = config_.num_filters(); + CHECK(config_.shared_biases()); + for (size_t i = 0; i < inputLayers_.size(); i++) { + ProjectionConfig* conf = new ProjectionConfig(); + conf->set_type("conv"); + conf->set_num_filters(numFilters_); + conf->set_allocated_conv_conf( + config_.mutable_inputs(i)->mutable_conv_conf()); + conf->set_input_size(getPrev(i)->getSize()); + conf->set_output_size(getSize()); + projConf_.emplace_back(conf); + projections_.emplace_back(Projection::create(*projConf_[i], + parameters_[i], useGpu_)); } if (biases_.get() && sharedBiases_) { hl_create_tensor_descriptor(&biasDesc_); + hl_create_tensor_descriptor(&outputDesc_); hl_tensor_reshape(biasDesc_, 1, numFilters_ / groups_[0], 1, 1); biasOffset_ = numFilters_ / groups_[0]; } - batchNum_ = 0; - isSelectAlgo_ = false; return true; } -void CudnnConvLayer::allocConvWorkSpace(size_t maxWorkSpace) { - size_t totalWorkSpace = maxWorkSpace * maxGroups_; - - if (totalWorkSpace > workSpaceInBytes_) { - if (workSpaceInBytes_ != 0) { - hl_free_mem_device(workSpaceData_); - } - // total amount of storage needed over all groups - workSpaceData_ = hl_malloc_device(totalWorkSpace); - - // update work space address for each group - for (int i = 0; i < maxGroups_; ++i) { - workSpace_[i] = reinterpret_cast(workSpaceData_) - + i * maxWorkSpace; - } - workSpaceInBytes_ = totalWorkSpace; - } -} - -void CudnnConvLayer::reshape(int batchSize) { - CHECK_NE(inputLayers_.size(), 0UL); - imageH_ = inputLayers_[0]->getOutput().getFrameHeight(); - imageW_ = inputLayers_[0]->getOutput().getFrameWidth(); - if (imageH_ == 0) imageH_ = imgSize_[0]; - if (imageW_ == 0) imageW_ = imgSize_[0]; - - for (size_t i = 1; i < inputLayers_.size(); i++) { - int imageH = inputLayers_[i]->getOutput().getFrameHeight(); - int imageW = inputLayers_[i]->getOutput().getFrameWidth(); - if (imageH) { - CHECK_EQ(imageH_, imageH) << "Inputs must have same height."; - } - if (imageW) { - CHECK_EQ(imageW_, imageW) << "Inputs must have same width."; - } - } - - outputH_ = outputSize(imageH_, filterSizeY_[0], paddingY_[0], strideY_[0]); - outputW_ = outputSize(imageW_, filterSize_[0], padding_[0], stride_[0]); - // check outputH & outputW - getOutput().setFrameHeight(outputH_); - getOutput().setFrameWidth(outputW_); - - // if the batchSize remains the same, set isSelectAlgo_ true. - // Otherwise, set isSelectAlgo_ false and select algo again. - isSelectAlgo_ = (batchSize == batchNum_); - batchNum_ = batchSize; - - size_t maxWorkSpace = 0; - for (size_t i = 0; i < inputLayers_.size(); i++) { - CHECK_EQ(inputLayers_[i]->getOutput().value->getWidth(), - (size_t)(channels_[i] * imageH_ * imageW_)); - - hl_tensor_reshape(inputDesc_[i], batchSize, channels_[i] / groups_[i], - imageH_, imageW_, channels_[i] * imageH_ * imageW_, - imageH_ * imageW_, imageW_, 1); - - hl_tensor_reshape(outputDesc_[i], batchSize, numFilters_ / groups_[i], - outputH_, outputW_, numFilters_ * outputH_ * outputW_, - outputH_ * outputW_, outputW_, 1); - - hl_reset_convolution_descriptor(convDesc_[i], inputDesc_[i], - filterDesc_[i], paddingY_[i], - padding_[i], strideY_[i], stride_[i]); - - inputOffset_[i] = (channels_[i] / groups_[i]) * imageH_ * imageW_; - outputOffset_[i] = (numFilters_ / groups_[i]) * outputH_ * outputW_; - - if (!isSelectAlgo_) { - hl_conv_workspace(inputDesc_[i], outputDesc_[i], filterDesc_[i], - convDesc_[i], &fwdAlgo_[i], &fwdLimitBytes_[i], - &bwdDataAlgo_[i], &bwdDataLimitBytes_[i], - &bwdFilterAlgo_[i], &bwdFilterLimitBytes_[i]); - - maxWorkSpace = std::max(fwdLimitBytes_[i], bwdDataLimitBytes_[i]); - maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_[i]); - - VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_[i] - << " / " << bwdDataAlgo_[i] - << " / " << bwdFilterAlgo_[i]; - } - } - - if (!isSelectAlgo_) { - allocConvWorkSpace(maxWorkSpace); - } - - isSelectAlgo_ = true; -} - void CudnnConvLayer::forward(PassType passType) { Layer::forward(passType); - int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); - reshape(batchSize); - resetOutput(batchSize, outputH_ * outputW_ * numFilters_); + + int batchSize = getInput(0).getBatchSize(); + resetOutput(batchSize, calOutputSize()); for (size_t i = 0; i != inputLayers_.size(); ++i) { - REGISTER_TIMER_INFO("CudnnConvFwTimer", getName().c_str()); - for (int g = 0; g < groups_[i]; ++g) { - real *inputData = getInputValue(i)->getData() + inputOffset_[i] * g; - real *wgtData = weights_[i]->getW()->getData() + weightOffset_[i] * g; - real *outData = getOutputValue()->getData() + outputOffset_[i] * g; - hl_convolution_forward(inputDesc_[i], inputData, outputDesc_[i], - outData, filterDesc_[i], wgtData, - convDesc_[i], workSpace_[g], - fwdLimitBytes_[i], fwdAlgo_[i]); - } + projections_[i]->forward(&getInput(i), &getOutput(), passType); } if (biases_) { REGISTER_TIMER_INFO("CudnnConvBiasTimer", getName().c_str()); - addBiases(); - } - - forwardActivation(); -} - -void CudnnConvLayer::addBiases() { - if (sharedBiases_) { + int batchSize = inputLayers_[0]->getOutputValue()->getHeight(); + hl_tensor_reshape(outputDesc_, batchSize, numFilters_ / groups_[0], + outputH_[0], outputW_[0], numFilters_ * outputH_[0] * outputW_[0], + outputH_[0] * outputW_[0], outputW_[0], 1); + outputOffset_ = getOutputValue()->getWidth() / groups_[0]; for (int g = 0; g < groups_[0]; ++g) { real *biasData = biases_->getW()->getData() + biasOffset_ * g; - real *outData = getOutputValue()->getData() + outputOffset_[0] * g; + real *outData = getOutputValue()->getData() + outputOffset_ * g; hl_convolution_forward_add_bias(biasDesc_, biasData, - outputDesc_[0], outData); + outputDesc_, outData); } - } else { - LOG(FATAL) << "Not supported"; } -} -void CudnnConvLayer::bpropBiases() { - if (sharedBiases_) { - for (int g = 0; g < groups_[0]; ++g) { - real *biasGrad = biases_->getWGrad()->getData() + biasOffset_ * g; - real *outGrad = getOutputGrad()->getData() + outputOffset_[0] * g; - hl_convolution_backward_bias(biasDesc_, biasGrad, - outputDesc_[0], outGrad); - } - } else { - LOG(FATAL) << "Not supported"; - } + forwardActivation(); } void CudnnConvLayer::backward(const UpdateCallback &callback) { @@ -238,52 +87,23 @@ void CudnnConvLayer::backward(const UpdateCallback &callback) { if (biases_ && biases_->getWGrad()) { REGISTER_TIMER_INFO("CudnnConvBpBiasTimer", getName().c_str()); - bpropBiases(); + for (int g = 0; g < groups_[0]; ++g) { + real *biasGrad = biases_->getWGrad()->getData() + biasOffset_ * g; + real *outGrad = getOutputGrad()->getData() + outputOffset_ * g; + hl_convolution_backward_bias(biasDesc_, biasGrad, outputDesc_, outGrad); + } biases_->getParameterPtr()->incUpdate(callback); } for (size_t i = 0; i != inputLayers_.size(); ++i) { - REGISTER_TIMER_INFO("CudnnConvBpTimer", getName().c_str()); - for (int g = 0; g < groups_[i]; ++g) { - real *outGrad = getOutputGrad()->getData() + outputOffset_[i] * g; - if (weights_[i]->getWGrad()) { - real *inputData = getInputValue(i)->getData() + inputOffset_[i] * g; - real *weightGrad = - weights_[i]->getWGrad()->getData() + weightOffset_[i] * g; - hl_convolution_backward_filter( - inputDesc_[i], inputData, outputDesc_[i], outGrad, filterDesc_[i], - weightGrad, convDesc_[i], workSpace_[g], bwdFilterLimitBytes_[i], - bwdFilterAlgo_[i]); - } - - MatrixPtr preGrad = getInputGrad(i); - if (NULL != preGrad) { - real *inputGrad = preGrad->getData() + inputOffset_[i] * g; - real *wgtData = weights_[i]->getW()->getData() + weightOffset_[i] * g; - hl_convolution_backward_data( - inputDesc_[i], inputGrad, outputDesc_[i], outGrad, filterDesc_[i], - wgtData, convDesc_[i], workSpace_[g], bwdDataLimitBytes_[i], - bwdDataAlgo_[i]); - } - } - weights_[i]->getParameterPtr()->incUpdate(callback); + projections_[i]->backward(callback); } } CudnnConvLayer::~CudnnConvLayer() { - if (biasDesc_) { + if (biases_) { hl_destroy_tensor_descriptor(biasDesc_); - } - - for (size_t i = 0; i < inputDesc_.size(); i++) { - hl_destroy_tensor_descriptor(inputDesc_[i]); - hl_destroy_tensor_descriptor(outputDesc_[i]); - hl_destroy_filter_descriptor(filterDesc_[i]); - hl_destroy_convolution_descriptor(convDesc_[i]); - } - if (workSpaceInBytes_ != 0) { - hl_free_mem_device(workSpaceData_); - workSpaceInBytes_ = 0; + hl_destroy_tensor_descriptor(outputDesc_); } } diff --git a/paddle/gserver/layers/CudnnConvLayer.h b/paddle/gserver/layers/CudnnConvLayer.h index a6dadba10daa4..6390d96315cc4 100644 --- a/paddle/gserver/layers/CudnnConvLayer.h +++ b/paddle/gserver/layers/CudnnConvLayer.h @@ -17,12 +17,13 @@ limitations under the License. */ #include "ConvBaseLayer.h" #include "paddle/math/Matrix.h" +#include "Projection.h" #include namespace paddle { /** - * @brief A subclass of ConvBaseLayer by cuDNN implementation. It only + * @brief A 2-dimension conv layer implemented by cuDNN. It only * supports GPU mode. We automatic select CudnnConvLayer for GPU * mode and ExpandConvLayer for CPU mode if you set type of "conv". * User also can specfiy type of "exconv" or "cudnn_conv" for @@ -31,81 +32,21 @@ namespace paddle { * The config file api is img_conv_layer. */ class CudnnConvLayer : public ConvBaseLayer { -private: - /// resize Cudnn workspace size - void allocConvWorkSpace(size_t maxWorkSpace); - protected: - int imageH_, imageW_, outputH_, outputW_; - /// Cudnn tensor descriptor for bias. + std::vector> projConf_; + std::vector> projections_; + hl_tensor_descriptor biasDesc_; - /// Cudnn tensor descriptor for input. - std::vector inputDesc_; - /// Cudnn tensor descriptor for output. - std::vector outputDesc_; - /// Cudnn tensor descriptor for filter. - std::vector filterDesc_; - /// Cudnn tensor descriptor for a convolution operation. - std::vector convDesc_; - /// One sample offset of input data. - IntV inputOffset_; - /// One sample offset of output data. - IntV outputOffset_; - /// One group offset of weight. - IntV weightOffset_; - /// One group offset of bias. + hl_tensor_descriptor outputDesc_; int biasOffset_; - - /// Save the algorithm for forward convolution, which is obtained by cudnn - /// api to search the best suited algorithm. - std::vector fwdAlgo_; - /// Save the algorithm for computing convolution gradient with respect to - /// filter coefficients. - std::vector bwdFilterAlgo_; - /// Save the algorithm for computing convolution gradient with respect to - /// the output. - std::vector bwdDataAlgo_; - /// Amount of GPU memory needed as workspace to be able to execute a - /// forward convolution with the specified algo. - std::vector fwdLimitBytes_; - /// Amount of GPU memory needed as workspace to be able to execute a - /// backwardFilter with the specified algo. - std::vector bwdFilterLimitBytes_; - /// Amount of GPU memory needed as workspace to be able to execute a - /// backwardData with the specified algo. - std::vector bwdDataLimitBytes_; - - /// Device work space address for each group. - std::vector workSpace_; - /// Max number of groups. - int maxGroups_; - /// Total work space address in device for all groups. - void* workSpaceData_; - /// Size of total work space. - size_t workSpaceInBytes_; - - /// Is or not select conv algorihtm. - bool isSelectAlgo_; - - /// batchNum is used to record batch size. If the batch size is changed, - /// the selection algorithm will be called. - int batchNum_; + int outputOffset_; public: explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {} ~CudnnConvLayer(); - /** - * Intialization. Initialize member variables and create tenor descriptor. - */ bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); - /** - * Reshape is done each forward. Reshape tensor decriptor - * inputDesc_, outputDesc_, convDesc_. And search the faster algo - * or the fastest algo within a given memeory limit. - */ - void reshape(int batchSize); void forward(PassType passType); void backward(const UpdateCallback& callback); void addBiases(); diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp index df79c3e3037cf..80a6a62b5c0de 100644 --- a/paddle/gserver/layers/ExpandConvLayer.cpp +++ b/paddle/gserver/layers/ExpandConvLayer.cpp @@ -37,32 +37,29 @@ bool ExpandConvLayer::init(const LayerMap &layerMap, caffeMode_ = conf.caffe_mode(); } + /* initialize the weightList */ + CHECK(inputLayers_.size() == parameters_.size()); + for (size_t i = 0; i < inputLayers_.size(); i++) { + size_t height, width; + height = filterPixels_[i] * filterChannels_[i]; + width = numFilters_; + + // create a new weight + CHECK_EQ(parameters_[i]->getSize(), width * height); + Weight* w = new Weight(height, width, parameters_[i]); + weights_.emplace_back(w); + } + return true; } -size_t ExpandConvLayer::getSize() { +size_t ExpandConvLayer::getOutputSize() { CHECK_NE(inputLayers_.size(), 0UL); - imgSizeH_.clear(); - imgSizeW_.clear(); - outputH_.clear(); - outputW_.clear(); + size_t layerSize = ConvBaseLayer::calOutputSize(); subN_.clear(); - size_t layerSize = 0; for (size_t i = 0; i < inputLayers_.size(); i++) { - imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight()); - imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth()); - if (imgSizeH_[i] == 0) imgSizeH_[i] = imgSize_[i]; - if (imgSizeW_[i] == 0) imgSizeW_[i] = imgSize_[i]; - outputH_.push_back( - outputSize(imgSizeH_[i], filterSize_[i], padding_[i], stride_[i])); - outputW_.push_back( - outputSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i])); subN_.push_back(outputH_[i] * outputW_[i]); - CHECK(layerSize == 0 || subN_[i] * size_t(numFilters_) == layerSize); - layerSize = subN_[i] * numFilters_; } - getOutput().setFrameHeight(outputH_[0]); - getOutput().setFrameWidth(outputW_[0]); return layerSize; } @@ -119,7 +116,7 @@ void ExpandConvLayer::expandFwdOnce(MatrixPtr image, int inIdx, int startIdx) { } void ExpandConvLayer::addSharedBias() { - size_t mapW = getSize() / numFilters_; + size_t mapW = getOutputValue()->getWidth() / numFilters_; size_t mapH = getOutputValue()->getElementCnt() / mapW; MatrixPtr out = Matrix::create(getOutputValue()->getData(), mapH, mapW, false, useGpu_); @@ -158,7 +155,7 @@ void ExpandConvLayer::forward(PassType passType) { * transOutValue correspond sample to one row */ int batchSize = inputLayers_[0]->getOutputValue()->getWidth(); batchSize = inputLayers_[0]->getOutputValue()->getHeight(); - resetOutput(batchSize, getSize()); + resetOutput(batchSize, getOutputSize()); MatrixPtr image = nullptr; for (size_t i = 0; i != inputLayers_.size(); ++i) { @@ -183,7 +180,7 @@ void ExpandConvLayer::forward(PassType passType) { } void ExpandConvLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) { - size_t mapW = getSize() / numFilters_; + size_t mapW = v->getWidth() / numFilters_; size_t mapH = v->getElementCnt() / mapW; MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_); diff --git a/paddle/gserver/layers/ExpandConvLayer.h b/paddle/gserver/layers/ExpandConvLayer.h index fc3d69b1b7d14..030a3ba397ff4 100644 --- a/paddle/gserver/layers/ExpandConvLayer.h +++ b/paddle/gserver/layers/ExpandConvLayer.h @@ -37,14 +37,6 @@ class ExpandConvLayer : public ConvBaseLayer { IntV subN_; /// subK_ = channels_ * filterPixels_ * groups_. IntV subK_; - /// The spatial dimensions of height of input feature map. - IntV imgSizeH_; - /// The spatial dimensions of width of input feature map. - IntV imgSizeW_; - /// The spatial dimensions of height of output feature map. - IntV outputH_; - /// The spatial dimensions of width of output feature map. - IntV outputW_; /// Expand one sample at a time. shape: /// (numChannels * filterPixels_, outputSizeH * outputSizeW) MatrixPtr expandInput_; @@ -58,7 +50,7 @@ class ExpandConvLayer : public ConvBaseLayer { bool init(const LayerMap& layerMap, const ParameterMap& parameterMap); - size_t getSize(); + size_t getOutputSize(); /** * Create or resize expandInput_. diff --git a/paddle/gserver/layers/MixedLayer.cpp b/paddle/gserver/layers/MixedLayer.cpp index 054ddd3a228ed..26b1360290ffb 100644 --- a/paddle/gserver/layers/MixedLayer.cpp +++ b/paddle/gserver/layers/MixedLayer.cpp @@ -41,9 +41,13 @@ bool MixedLayer::init(const LayerMap& layerMap, } operators_.emplace_back(Operator::create(operator_conf, useGpu_)); } + /* initialize biases_ */ if (biasParameter_.get() != NULL) { - biases_ = std::unique_ptr(new Weight(1, getSize(), biasParameter_)); + sharedBias_ = config_.shared_biases(); + size_t psize = config_.bias_size(); + biases_ = std::unique_ptr( + new Weight(1, psize, biasParameter_)); } return true; @@ -119,12 +123,6 @@ void MixedLayer::forward(PassType passType) { MatrixPtr outV = getOutputValue(); - /* add the bias-vector */ - if (biases_.get() != NULL) { - REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); - outV->addBias(*(biases_->getW()), 1); - } - for (size_t i = 0; i != inputLayers_.size(); ++i) { if (projections_[i]) { projections_[i]->forward(&getInput(i), &output_, passType); @@ -140,6 +138,12 @@ void MixedLayer::forward(PassType passType) { op->forward(ins, &output_, passType); } + /* add the bias-vector */ + if (biases_.get() != NULL) { + REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str()); + outV->addBias(*(biases_->getW()), 1, sharedBias_); + } + /* activation */ { REGISTER_TIMER_INFO("FwAtvTimer", getName().c_str()); forwardActivation(); @@ -154,7 +158,7 @@ void MixedLayer::backward(const UpdateCallback& callback) { if (biases_ && biases_->getWGrad()) { REGISTER_TIMER_INFO("BpBiasTimer", getName().c_str()); - biases_->getWGrad()->collectBias(*getOutputGrad(), 1); + biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBias_); /* Increasing the number of gradient */ biases_->getParameterPtr()->incUpdate(callback); diff --git a/paddle/gserver/layers/MixedLayer.h b/paddle/gserver/layers/MixedLayer.h index 9bac1355bd21f..5842e51e1d79d 100644 --- a/paddle/gserver/layers/MixedLayer.h +++ b/paddle/gserver/layers/MixedLayer.h @@ -58,5 +58,6 @@ class MixedLayer : public Layer { /// the matrix size of projection state std::vector projectionStateMatrixSize_; std::unique_ptr biases_; + bool sharedBias_; }; } // namespace paddle diff --git a/paddle/gserver/tests/LayerGradUtil.cpp b/paddle/gserver/tests/LayerGradUtil.cpp index 552a6c5b41c7f..bc7bee0e4bbc8 100644 --- a/paddle/gserver/tests/LayerGradUtil.cpp +++ b/paddle/gserver/tests/LayerGradUtil.cpp @@ -669,12 +669,14 @@ void testLayerGrad(TestConfig testConf, string testLayerName, size_t batchSize, void testProjectionGrad(ProjectionConfig conf, InputType inputType, size_t parameterSize, size_t batchSize, bool useGpu, - bool testState) { + bool testState, int biasSize, bool sharedBias) { TestConfig config; conf.set_name(conf.type()); config.layerConfig.set_type("mixed"); config.layerConfig.set_size(conf.output_size()); - config.biasSize = config.layerConfig.size(); + config.biasSize = biasSize == 0 ? config.layerConfig.size() : biasSize; + config.layerConfig.set_bias_size(config.biasSize); + config.layerConfig.set_shared_biases(sharedBias); config.inputDefs.push_back( {inputType, "layer_0", conf.input_size(), parameterSize}); *config.layerConfig.add_inputs()->mutable_proj_conf() = conf; diff --git a/paddle/gserver/tests/LayerGradUtil.h b/paddle/gserver/tests/LayerGradUtil.h index 1e608dc0620ab..3b9ec803959b3 100644 --- a/paddle/gserver/tests/LayerGradUtil.h +++ b/paddle/gserver/tests/LayerGradUtil.h @@ -217,7 +217,8 @@ void testLayerGrad(TestConfig testConf, string testLayerName, size_t batchSize, void testProjectionGrad(ProjectionConfig conf, InputType inputType, size_t parameterSize, size_t batchSize, bool useGpu, - bool testState = false); + bool testState = false, int biasSize = 0, + bool sharedBias = false); void testOperatorGrad(TestConfig& config, OperatorConfig& operatorConf, size_t batchSize, bool useGpu, bool testState = false); diff --git a/paddle/gserver/tests/img_conv_a.conf b/paddle/gserver/tests/img_conv_a.conf new file mode 100644 index 0000000000000..940589ed9ac24 --- /dev/null +++ b/paddle/gserver/tests/img_conv_a.conf @@ -0,0 +1,39 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) +data = data_layer(name ="input", size=8*16*16) +conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) +conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=False, + act=ReluActivation()) + +concat = concat_layer(input=[conv1, conv2]) + +conv = img_conv_layer(input=data, filter_size=1, filter_size_y=1, + num_channels=8, + num_filters=16, stride=1, + bias_attr=True, + act=LinearActivation()) + +outputs(concat, conv) diff --git a/paddle/gserver/tests/img_conv_b.conf b/paddle/gserver/tests/img_conv_b.conf new file mode 100644 index 0000000000000..8ca9c94541504 --- /dev/null +++ b/paddle/gserver/tests/img_conv_b.conf @@ -0,0 +1,32 @@ +#edit-mode: -*- python -*- +# Copyright (c) 2016 Baidu, Inc. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.trainer_config_helpers import * + +settings(batch_size=10) +data = data_layer(name ="input", size=8*16*16) +proj1 = conv_projection(input=data, filter_size=1, filter_size_y=1, + num_channels=8, num_filters=16, stride=1) +proj2 = conv_projection(input=data, filter_size=1, filter_size_y=1, + num_channels=8, num_filters=16, stride=1) +concat = concat_layer(input=[proj1, proj2], bias_attr=False, act=ReluActivation()) + +proj = conv_projection(input=data, filter_size=1, filter_size_y=1, + num_channels=8, num_filters=16, stride=1) + +with mixed_layer(bias_attr=True, act=LinearActivation()) as conv: + conv += proj + +outputs(concat, conv) diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp index eab9bf84141a2..bf2c2e0499941 100644 --- a/paddle/gserver/tests/test_LayerGrad.cpp +++ b/paddle/gserver/tests/test_LayerGrad.cpp @@ -134,6 +134,45 @@ TEST(Projection, identity) { } } + +#ifndef PADDLE_ONLY_CPU +TEST(Projection, conv) { + const int NUM_FILTERS = 16; + const int FILTER_SIZE = 2; + const int FILTER_SIZE_Y = 3; + const int CHANNELS = 3; + const int IMAGE_SIZE = 16; + + ProjectionConfig conf; + conf.set_type("conv"); + conf.set_num_filters(NUM_FILTERS); + + ConvConfig* conv = conf.mutable_conv_conf(); + conv->set_filter_size(FILTER_SIZE); + conv->set_filter_size_y(FILTER_SIZE_Y); + conv->set_channels(CHANNELS); + conv->set_padding(0); + conv->set_padding_y(1); + conv->set_stride(2); + conv->set_stride_y(2); + conv->set_groups(1); + conv->set_filter_channels(conv->channels() / conv->groups()); + conv->set_img_size(IMAGE_SIZE); + int outputSize = (2 * conv->padding() + conv->img_size() - + conv->filter_size()) / conv->stride() + 1; + int outputSizeY = (2 * conv->padding_y() + conv->img_size() - + conv->filter_size_y()) / conv->stride_y() + 1; + conv->set_output_x(outputSize); + conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS); + conf.set_output_size(outputSize * outputSizeY * NUM_FILTERS); + + testProjectionGrad(conf, INPUT_DATA, + /* parameterSize */ NUM_FILTERS * CHANNELS * FILTER_SIZE * FILTER_SIZE_Y, + /* batchSize */ 100, true, false, NUM_FILTERS, true); +} +#endif + + TEST(Layer, concat) { TestConfig config; config.biasSize = 0; diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp index b3ef53067301b..8d3eac5aca8d1 100644 --- a/paddle/gserver/tests/test_NetworkCompare.cpp +++ b/paddle/gserver/tests/test_NetworkCompare.cpp @@ -236,6 +236,15 @@ TEST(Compare, img_pool) { compareNetwork(config_file_a, config_file_b); FLAGS_use_gpu = useGpu; } + +TEST(Compare, img_conv) { + std::string config_file_a = "./gserver/tests/img_conv_a.conf"; + std::string config_file_b = "./gserver/tests/img_conv_b.conf"; + bool useGpu = FLAGS_use_gpu; + FLAGS_use_gpu = true; + compareNetwork(config_file_a, config_file_b); + FLAGS_use_gpu = useGpu; +} #endif diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp index 843eabc97d642..aaeae98f0d28b 100644 --- a/paddle/math/Matrix.cpp +++ b/paddle/math/Matrix.cpp @@ -340,6 +340,15 @@ void GpuMatrix::addBias(Matrix& b, real scale) { BaseMatrix::addBias(b, scale); } +void GpuMatrix::addSharedBias(Matrix& b, real scale) { + CHECK(b.getHeight() == 1) << "the Bias should be a vector"; + CHECK_LE(b.getWidth(), getWidth()); + CHECK_EQ(getWidth() % b.getWidth(), 0UL); + hl_matrix_add_shared_bias(getData(), b.getData(), b.getWidth(), + getHeight(), getWidth(), scale); +} + + void GpuMatrix::collectBias(Matrix& a, real scale) { CHECK_EQ(getHeight(), (size_t)1); CHECK_EQ(width_, a.getWidth()); @@ -354,6 +363,14 @@ void GpuMatrix::collectBias(Matrix& a, real scale) { } } +void GpuMatrix::collectSharedBias(Matrix& a, real scale) { + CHECK_EQ(getHeight(), (size_t)1); + CHECK_EQ(a.getWidth() % getWidth(), 0UL); + hl_matrix_collect_shared_bias(getData(), a.getData(), getWidth(), + a.getHeight(), a.getWidth(), scale); +} + + void GpuMatrix::sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode) { @@ -1983,6 +2000,24 @@ void CpuMatrix::addBias(Matrix& b, real scale) { } } +void CpuMatrix::addSharedBias(Matrix& b, real scale) { + CHECK_EQ(b.getHeight(), (size_t)1); + real* aData = getData(); + real* bData = b.getData(); + size_t numSamples = getHeight(); + size_t channel = b.getWidth(); + CHECK_EQ(getWidth() % channel, 0UL); + size_t dim = getWidth() / channel; + + for (size_t i = 0; i < numSamples; i++) { + for (size_t c = 0; c < channel; c++) { + for (size_t j = 0; j < dim; j++) { + aData[i * getStride() + c * dim + j] += scale * bData[c]; + } + } + } +} + void CpuMatrix::collectBias(Matrix& a, real scale) { CHECK_EQ(getHeight(), (size_t)1); CHECK_EQ(width_, a.getWidth()); @@ -2000,6 +2035,23 @@ void CpuMatrix::collectBias(Matrix& a, real scale) { } } +void CpuMatrix::collectSharedBias(Matrix& a, real scale) { + CHECK_EQ(getHeight(), (size_t)1); + real* B = getData(); + real* A = a.getData(); + size_t numSamples = a.getHeight(); + size_t channel = getWidth(); + CHECK_EQ(a.getWidth() % channel, 0UL); + size_t dim = a.getWidth() / channel; + for (size_t i = 0; i < numSamples; i++) { + for (size_t c = 0; c < channel; c++) { + for (size_t j = 0; j < dim; j++) { + B[c] += scale * A[i * channel * dim + c * dim + j]; + } + } + } +} + void CpuMatrix::sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode) { diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h index 047c76a8604cc..52cbed528ca8b 100644 --- a/paddle/math/Matrix.h +++ b/paddle/math/Matrix.h @@ -343,11 +343,35 @@ class Matrix : public BaseMatrix { LOG(FATAL) << "Not implemented"; } + virtual void addSharedBias(Matrix& b, real scale) { + LOG(FATAL) << "Not implemented"; + } + + virtual void addBias(Matrix& b, real scale, bool sharedBias) { + if (!sharedBias) { + addBias(b, scale); + } else { + addSharedBias(b, scale); + } + } + /// add each sample from a to this. virtual void collectBias(Matrix& a, real scale) { LOG(FATAL) << "Not implemented"; } + virtual void collectSharedBias(Matrix& a, real scale) { + LOG(FATAL) << "Not implemented"; + } + + virtual void collectBias(Matrix& a, real scale, bool sharedBias) { + if (!sharedBias) { + collectBias(a, scale); + } else { + collectSharedBias(a, scale); + } + } + virtual void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode) { LOG(FATAL) << "Not implemented"; @@ -1021,6 +1045,7 @@ class GpuMatrix : public Matrix { /// add b to each sample of this. void addBias(Matrix& b, real scale); + void addSharedBias(Matrix& b, real scale); /** * @code @@ -1028,6 +1053,7 @@ class GpuMatrix : public Matrix { * @endcode */ void collectBias(Matrix& a, real scale); + void collectSharedBias(Matrix& a, real scale); void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode); @@ -1341,9 +1367,11 @@ class CpuMatrix : public Matrix { public: /// add b to each sample of this. void addBias(Matrix& b, real scale); + void addSharedBias(Matrix& b, real scale); /// add each sample of a to this. void collectBias(Matrix& a, real scale); + void collectSharedBias(Matrix& a, real scale); void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode); diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp index ac160479a9dfc..0ddf7e0dfc386 100644 --- a/paddle/math/tests/test_matrixCompare.cpp +++ b/paddle/math/tests/test_matrixCompare.cpp @@ -21,6 +21,8 @@ limitations under the License. */ #include "paddle/math/SparseMatrix.h" #include #include "paddle/gserver/tests/TestUtil.h" +#include "paddle/utils/Stat.h" + using namespace paddle; // NOLINT using namespace std; // NOLINT @@ -2071,6 +2073,60 @@ TEST(Matrix, MaxOutFwdBwd) { } } +void testAddSharedBias(int numSamples, int dim, int channel) { + MatrixPtr cpuData = std::make_shared(numSamples, dim); + MatrixPtr gpuData = std::make_shared(numSamples, dim); + + MatrixPtr cpuBias = std::make_shared(1, channel); + MatrixPtr gpuBias = std::make_shared(1, channel); + + cpuData->randomizeUniform(); + gpuData->copyFrom(*cpuData); + cpuBias->randomizeUniform(); + gpuBias->copyFrom(*cpuBias); + + cpuData->addSharedBias(*cpuBias, 1.0); + gpuData->addSharedBias(*gpuBias, 1.0); + + MatrixPtr check = std::make_shared(numSamples, dim); + check->copyFrom(*gpuData); + MatrixCheckErr(*cpuData, *check); +} + +void testCollectSharedBias(int numSamples, int dim, int channel) { + MatrixPtr cpuData = std::make_shared(numSamples, dim); + MatrixPtr gpuData = std::make_shared(numSamples, dim); + + MatrixPtr cpuBias = std::make_shared(1, channel); + MatrixPtr gpuBias = std::make_shared(1, channel); + + cpuData->randomizeUniform(); + gpuData->copyFrom(*cpuData); + cpuBias->randomizeUniform(); + gpuBias->copyFrom(*cpuBias); + + cpuBias->collectSharedBias(*cpuData, 1.0); + gpuBias->collectSharedBias(*gpuData, 1.0); + + MatrixPtr check = std::make_shared(1, channel); + check->copyFrom(*gpuBias); + MatrixCheckErr(*cpuBias, *check); +} + + +TEST(Matrix, sharedBias) { + for (auto numSamples : {1, 100, 520}) { + for (auto dim : {100 * 16, 100 * 32}) { + for (auto channel : {8, 16}) { + VLOG(3) << " numSamples=" << numSamples << " dim=" << dim + << " channel=" << channel; + testAddSharedBias(numSamples, dim, channel); + testCollectSharedBias(numSamples, dim, channel); + } + } + } +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); initMain(argc, argv); diff --git a/paddle/trainer/CMakeLists.txt b/paddle/trainer/CMakeLists.txt index 08b411d2ccbae..06c019f0a9775 100644 --- a/paddle/trainer/CMakeLists.txt +++ b/paddle/trainer/CMakeLists.txt @@ -7,6 +7,7 @@ set(TRAINER_SOURCES Tester.cpp Trainer.cpp TrainerInternal.cpp + TrainerBenchmark.cpp ThreadParameterUpdater.cpp TrainerInternalConfig.cpp TrainerConfigHelper.cpp) diff --git a/paddle/trainer/Trainer.h b/paddle/trainer/Trainer.h index 4f4811a139e74..7762722456c44 100644 --- a/paddle/trainer/Trainer.h +++ b/paddle/trainer/Trainer.h @@ -99,6 +99,7 @@ class Trainer { void startTrainPass(); void finishTrainPass(); void trainOneDataBatch(DataBatch& dataBatch); + void time(); /** * given a dataBatch and the current parameter value diff --git a/paddle/trainer/TrainerBenchmark.cpp b/paddle/trainer/TrainerBenchmark.cpp new file mode 100644 index 0000000000000..54862e95b4a73 --- /dev/null +++ b/paddle/trainer/TrainerBenchmark.cpp @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#undef PADDLE_DISABLE_TIMER + +#include "Trainer.h" +#include "paddle/utils/Stat.h" +#include "paddle/utils/Util.h" + +P_DECLARE_int32(test_period); + +P_DEFINE_bool(feed_data, false, "Wether to read data from DataProvider."); + +namespace paddle { + +void Trainer::time() { + startTrain(); + + trainerInternal_.getParameterUpdater()->startPass(); + evaluator_->start(); + + DataBatch dataBatch; + int32_t batchSize = config_->getOptConfig().batch_size(); + int32_t num = dataProvider_->getNextBatch(batchSize, &dataBatch); + CHECK_EQ(num, batchSize) << "The sample number is less than batch size " + << num << " != " << batchSize; + + CHECK(dataBatch.getSize()) << "No data from data provider"; + + std::vector outputs; + // burning time + LOG(INFO) << "Burning time..."; + for (int n = 0; n < 10; ++n) { + trainerInternal_.trainOneBatch(n, dataBatch, &outputs); + } + LOG(INFO) << "Burning time end."; + + for (int n = 0; n < FLAGS_test_period; n++) { + if (FLAGS_feed_data) { + REGISTER_TIMER("GetData"); + num = dataProvider_->getNextBatch(batchSize, &dataBatch); + } + + if (num != batchSize) { + break; + } + + { + REGISTER_TIMER("FwdBwd"); + trainerInternal_.trainOneBatch(n, dataBatch, &outputs); + } + } + globalStat.setThreadInfo(true); + globalStat.printSegTimerStatus(); + globalStat.reset(); + + finishTrain(); +} + +} // namespace paddle diff --git a/paddle/trainer/TrainerMain.cpp b/paddle/trainer/TrainerMain.cpp index 94266639f94ad..a486cc383ace6 100644 --- a/paddle/trainer/TrainerMain.cpp +++ b/paddle/trainer/TrainerMain.cpp @@ -103,6 +103,8 @@ int main(int argc, char** argv) { trainer.checkGradient(); } else if (FLAGS_job == "test") { trainer.test(); + } else if (FLAGS_job == "time") { + trainer.time(); } else { LOG(FATAL) << "Unknown job type: " << FLAGS_job; } diff --git a/proto/ModelConfig.proto.m4 b/proto/ModelConfig.proto.m4 index 70c1f8d563238..79e76b6bf1bdd 100644 --- a/proto/ModelConfig.proto.m4 +++ b/proto/ModelConfig.proto.m4 @@ -255,7 +255,7 @@ sinclude(`ModelConfigLayer.proto.m4') // (which is how convnets are usually trained). Setting this to // false will untie the biases, yielding a separate bias for // every location at which the filter is applied. - optional bool shared_biases = 8; + optional bool shared_biases = 8 [default = false]; // Valid values are ones that divide the area of the output // grid in this convolutional layer. For example if this layer @@ -379,6 +379,9 @@ sinclude(`ModelConfigLayer.proto.m4') // use to compute moving mean and variance. optional real moving_average_fraction = 47 [default = 0.9]; + + // bias size + optional uint32 bias_size = 48 [default = 0]; } message EvaluatorConfig { diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py index fe8a5e5d48767..e9098943165fd 100644 --- a/python/paddle/trainer/config_parser.py +++ b/python/paddle/trainer/config_parser.py @@ -632,6 +632,44 @@ def calc_parameter_dims(self, input_size, output_size): _total_pad = 0 +@config_class +class ConvProjection(Projection): + type = 'conv' + + def __init__( + self, + input_layer_name, + num_filters=None, + conv_conf=None, + **xargs): + super(ConvProjection, self).__init__(input_layer_name, **xargs) + + if num_filters is not None: + self.proj_conf.num_filters = num_filters + + parse_conv(conv_conf, + input_layer_name, + self.proj_conf.conv_conf) + # TODO: support rectangle input + self.proj_conf.output_size = (self.proj_conf.conv_conf.output_x ** 2) * num_filters + + def calc_output_size(self, input_layer_config): + return self.proj_conf.output_size + + def calc_parameter_size(self, input_size, output_size): + co = self.proj_conf.num_filters + ci = self.proj_conf.conv_conf.channels + fh = self.proj_conf.conv_conf.filter_size + fw = self.proj_conf.conv_conf.filter_size_y + return co * ci * fh * fw + + def calc_bias_size(self): + return self.proj_conf.num_filters + + def calc_parameter_dims(self, input_size, output_size): + return None + + # Define a operator for mixed layer @config_class class Operator(Cfg): @@ -2528,8 +2566,15 @@ def __init__( record_operator_conf = self.config.operator_confs.add() record_operator_conf.CopyFrom(operator_conf) + psize = self.config.size + if isinstance(self.inputs[0], ConvProjection): + self.config.shared_biases = True + psize = 0 + for input in self.inputs: + psize += input.calc_bias_size() - self.create_bias_parameter(bias, self.config.size) + self.config.bias_size = psize + self.create_bias_parameter(bias, psize) if error_clipping_threshold is not None: self.config.error_clipping_threshold = error_clipping_threshold @@ -2547,8 +2592,10 @@ def __init__( self, name, inputs, + bias=False, **xargs): config_assert(inputs, 'inputs cannot be empty') + config_assert(not bias, 'ConcatenateLayer cannot support bias.') super(ConcatenateLayer, self).__init__( name, 'concat', 0, inputs=inputs, **xargs) size = 0 @@ -2567,10 +2614,19 @@ def __init__( self, name, inputs, + bias=False, **xargs): config_assert(inputs, 'inputs cannot be empty') super(ConcatenateLayer2, self).__init__( name, 'concat2', 0, inputs=inputs, **xargs) + + if isinstance(self.inputs[0], ConvProjection): + for input_index in xrange(len(self.inputs) - 1): + input = self.inputs[input_index + 1] + config_assert(isinstance(input, ConvProjection), + "The first input of ConcatenateLayer2 is ConvProjection, " + "the other inputs should also be ConvProjection.") + size = 0 for input_index in xrange(len(self.inputs)): input_layer = self.get_input_layer(input_index) @@ -2596,6 +2652,16 @@ def __init__( input.proj_conf.output_size) self.create_input_parameter(input_index, psize, dims) + psize = self.config.size + if isinstance(self.inputs[0], ConvProjection): + self.config.shared_biases = True + psize = 0 + for input in self.inputs: + psize += input.calc_bias_size() + + self.config.bias_size = psize + self.create_bias_parameter(bias, psize) + @config_layer('recurrent') class RecurrentLayer(LayerBase): def __init__( diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py index 7df9108ae82a4..9a23c02431d18 100644 --- a/python/paddle/trainer_config_helpers/layers.py +++ b/python/paddle/trainer_config_helpers/layers.py @@ -34,7 +34,7 @@ "table_projection", "mixed_layer", "data_layer", "embedding_layer", "fc_layer", "grumemory", "pooling_layer", "lstmemory", "last_seq", "first_seq", - "cos_sim", "hsigmoid", + "cos_sim", "hsigmoid", "conv_projection", "regression_cost", 'classification_cost', "LayerOutput", 'img_conv_layer', 'img_pool_layer', 'batch_norm_layer', 'img_cmrnorm_layer', 'addto_layer', @@ -1984,7 +1984,7 @@ def addto_layer(input, act=None, name=None, bias_attr=None, @wrap_act_default(act=IdentityActivation()) @wrap_name_default("concat") @layer_support() -def concat_layer(input, act=None, name=None, layer_attr=None): +def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None): """ Concat all input vector into one huge vector. Inputs can be list of LayerOutput or list of projection. @@ -2043,10 +2043,14 @@ def __reduce_concat_type__(a, b): layer_type = (LayerType.CONCAT_LAYER if is_concat_layer else LayerType.CONCAT_PROJ_LAYER) + if layer_type == LayerType.CONCAT_LAYER: + assert not bias_attr + Layer( name=name, type=layer_type, inputs=[x.name for x in input] if is_concat_layer else input, active_type=act.name, + bias=ParamAttr.to_bias(bias_attr), **ExtraLayerAttribute.to_kwargs(layer_attr) ) @@ -2950,6 +2954,103 @@ def conv_operator(img, filter, filter_size, num_filters, op.origin = [img, filter] return op +@wrap_param_attr_default() +def conv_projection(input, filter_size, num_filters, + num_channels=None, stride=1, padding=0, + filter_size_y=None, stride_y=None, padding_y=None, + groups=1, param_attr=None): + """ + ConvProjection with a layer as input. + It performs element-wise multiplication with weight. + + Different from img_conv_layer and conv_op, conv_projection is an Projection, + which can be used in mixed_layer and conat_layer. It use cudnn to implement + conv and only support GPU mode. + + The example usage is: + + .. code-block:: python + + proj = conv_projection(img=input1, + filter_size=3, + num_filters=64, + num_channels=64) + + :param input: input layer + :type input: LayerOutput + :param filter_size: The x dimension of a filter kernel. + :type filter_size: int + :param filter_size_y: The y dimension of a filter kernel. Since + PaddlePaddle now supports rectangular filters, + the filter's shape can be (filter_size, filter_size_y). + :type filter_size_y: int + :param num_filters: channel of output data. + :type num_filters: int + :param num_channel: channel of input data. + :type num_channel: int + :param stride: The x dimension of the stride. + :type stride: int + :param stride_y: The y dimension of the stride. + :type stride_y: int + :param padding: The x dimension of padding. + :type padding: int + :param padding_y: The y dimension of padding. + :type padding_y: int + :param groups: The group number. + :type groups: int + :param param_attr: Convolution param attribute. None means default attribute + :type param_attr: ParameterAttribute + :return: A DotMulProjection Object. + :rtype: DotMulProjection + """ + if num_channels is None: + assert input.num_filters is not None + num_channels = input.num_filters + + if filter_size_y is None: + if isinstance(filter_size, collections.Sequence): + assert len(filter_size) == 2 + filter_size, filter_size_y = filter_size + else: + filter_size_y = filter_size + + if stride_y is None: + if isinstance(stride, collections.Sequence): + assert len(stride) == 2 + stride, stride_y = stride + else: + stride_y = stride + + if padding_y is None: + if isinstance(padding, collections.Sequence): + assert len(padding) == 2 + padding, padding_y = padding + else: + padding_y = padding + + if param_attr.attr.get('initial_smart'): + # special initial for conv layers. + init_w = (2.0 / (filter_size ** 2 * num_channels)) ** 0.5 + param_attr.attr["initial_mean"] = 0.0 + param_attr.attr["initial_std"] = init_w + param_attr.attr["initial_strategy"] = 0 + param_attr.attr["initial_smart"] = False + + proj = ConvProjection(input_layer_name=input.name, + num_filters=num_filters, + conv_conf=Conv(filter_size=filter_size, + padding=padding, + stride=stride, + channels=num_channels, + filter_size_y=filter_size_y, + padding_y=padding_y, + stride_y=stride_y, + groups=groups), + **param_attr.attr) + + proj.origin = input + return proj + @wrap_name_default() @layer_support() diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py index 65512b327cdc6..bce88f93626ec 100644 --- a/python/paddle/trainer_config_helpers/networks.py +++ b/python/paddle/trainer_config_helpers/networks.py @@ -29,7 +29,7 @@ "img_conv_bn_pool", 'dropout_layer', 'lstmemory_group', 'lstmemory_unit', 'small_vgg', 'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru', 'simple_attention', - 'text_conv_pool', + 'simple_gru2', 'bidirectional_gru', 'text_conv_pool', 'bidirectional_lstm', 'inputs', 'outputs'] @@ -811,22 +811,37 @@ def simple_gru(input, gru_layer_attr=None ): """ - simple_gru is also a recurrent layer group version Gated Recurrent Unit as - gru_group. The difference only lies in implemention details. + You maybe see gru_step_layer, grumemory in layers.py, gru_unit, gru_group, + simple_gru in network.py. The reason why there are so many interfaces is + that we have two ways to implement recurrent neural network. One way is to + use one complete layer to implement rnn (including simple rnn, gru and lstm) + with multiple time steps, such as recurrent_layer, lstmemory, grumemory. But, + the multiplication operation :math:`W x_t` is not computed in these layers. + See details in their interfaces in layers.py. + The other implementation is to use an recurrent group which can ensemble a + series of layers to compute rnn step by step. This way is flexible for + attenion mechanism or other complex connections. + + - gru_step_layer: only compute rnn by one step. It needs an memory as input + and can be used in recurrent group. + - gru_unit: a wrapper of gru_step_layer with memory. + - gru_group: a GRU cell implemented by a combination of multiple layers in + recurrent group. + But :math:`W x_t` is not done in group. + - gru_memory: a GRU cell implemented by one layer, which does same calculation + with gru_group and is faster than gru_group. + - simple_gru: a complete GRU implementation inlcuding :math:`W x_t` and + gru_group. :math:`W` contains :math:`W_r`, :math:`W_z` and :math:`W`, see + formula in grumemory. + The computational speed is that, grumemory is relatively better than gru_group, and gru_group is relatively better than simple_gru. - simple_gru does exactly the same calculation as the grumemory layer does. - Please see grumemory in layers.py for more detail about the maths. - The example usage is: .. code-block:: python - gru = gur_group(input=[layer1], - size=256, - act=TanhActivation(), - gate_act=SigmoidActivation()) + gru = simple_gru(input=[layer1], size=256) :param input: input layer name. :type input: LayerOutput @@ -863,6 +878,132 @@ def simple_gru(input, gru_layer_attr=gru_layer_attr) +@wrap_name_default('simple_gru2') +def simple_gru2(input, + size, + name=None, + reverse=False, + mixed_param_attr=None, + mixed_bias_attr=None, + gru_param_attr=None, + gru_bias_attr=None, + act=None, + gate_act=None, + mixed_layer_attr=None, + gru_cell_attr=None + ): + """ + simple_gru2 is the same with simple_gru, but using grumemory instead + Please see grumemory in layers.py for more detail about the maths. + simple_gru2 is faster than simple_gru. + + The example usage is: + + .. code-block:: python + + gru = simple_gru2(input=[layer1], size=256) + + :param input: input layer name. + :type input: LayerOutput + :param name: name of the gru group. + :type name: basestring + :param size: hidden size of the gru. + :type size: int + :param reverse: whether to process the input data in a reverse order + :type reverse: bool + :param act: type of the activiation + :type act: BaseActivation + :param gate_act: type of the gate activiation + :type gate_act: BaseActivation + :param gru_bias_attr: bias. False means no bias, None means default bias. + :type gru_bias_attr: ParameterAttribute|False + :param gru_layer_attr: Extra parameter attribute of the gru layer. + :type gru_layer_attr: ParameterAttribute|False + :return: the gru group. + :rtype: LayerOutput + """ + with mixed_layer(name='%s_transform' % name, + size=size * 3, + bias_attr=mixed_bias_attr, + layer_attr=mixed_layer_attr) as m: + m += full_matrix_projection(input=input, param_attr=mixed_param_attr) + + return grumemory(name=name, + size=size, + input=m, + reverse=reverse, + bias_attr=gru_bias_attr, + param_attr=gru_param_attr, + act=act, + gate_act=gate_act, + layer_attr=gru_cell_attr) + + +@wrap_name_default("bidirectional_gru") +def bidirectional_gru(input, size, name=None, return_seq=False, + fwd_mixed_param_attr=None, fwd_mixed_bias_attr=None, + fwd_gru_param_attr=None, fwd_gru_bias_attr=None, + fwd_act=None, fwd_gate_act=None, + fwd_mixed_layer_attr=None, fwd_gru_cell_attr=None, + + bwd_mixed_param_attr=None, bwd_mixed_bias_attr=None, + bwd_gru_param_attr=None, bwd_gru_bias_attr=None, + bwd_act=None, bwd_gate_act=None, + bwd_mixed_layer_attr=None, bwd_gru_cell_attr=None, + + last_seq_attr=None, first_seq_attr=None, + concat_attr=None, concat_act=None): + """ + A bidirectional_gru is a recurrent unit that iterates over the input + sequence both in forward and bardward orders, and then concatenate two + outputs to form a final output. However, concatenation of two outputs + is not the only way to form the final output, you can also, for example, + just add them together. + + The example usage is: + + .. code-block:: python + + bi_gru = bidirectional_gru(input=[input1], size=512) + + :param name: bidirectional gru layer name. + :type name: basestring + :param input: input layer. + :type input: LayerOutput + :param size: gru layer size. + :type size: int + :param return_seq: If set False, outputs of the last time step are + concatenated and returned. + If set True, the entire output sequences that are + processed in forward and backward directions are + concatenated and returned. + :type return_seq: bool + :return: LayerOutput object. + :rtype: LayerOutput + """ + args = locals() + + fw = simple_gru2(name='%s_fw' % name, input=input, size=size, + **dict((k[len('fwd_'):], v) for k, v in args.iteritems() + if k.startswith('fwd_'))) + + bw = simple_gru2(name="%s_bw" % name, input=input, size=size, + reverse=True, + **dict((k[len('bwd_'):], v) for k, v in args.iteritems() + if k.startswith('bwd_'))) + + if return_seq: + return concat_layer(name=name, input=[fw, bw], layer_attr=concat_attr, + act=concat_act) + else: + fw_seq = last_seq(name="%s_fw_last" % name, input=fw, + layer_attr=last_seq_attr) + bw_seq = first_seq(name="%s_bw_last" % name, input=bw, + layer_attr=first_seq_attr) + return concat_layer(name=name, input=[fw_seq, bw_seq], + layer_attr=concat_attr, act=concat_act) + + @wrap_name_default("bidirectional_lstm") def bidirectional_lstm(input, size, name=None, return_seq=False, fwd_mat_param_attr=None, fwd_bias_param_attr=None, @@ -893,7 +1034,7 @@ def bidirectional_lstm(input, size, name=None, return_seq=False, .. code-block:: python - lstm_step = bidirectional_lstm(input=[input1], size=512) + bi_lstm = bidirectional_lstm(input=[input1], size=512) :param name: bidirectional lstm layer name. :type name: basestring @@ -907,7 +1048,7 @@ def bidirectional_lstm(input, size, name=None, return_seq=False, processed in forward and backward directions are concatenated and returned. :type return_seq: bool - :return: lstm layer name. + :return: LayerOutput object accroding to the return_seq. :rtype: LayerOutput """ args = locals() diff --git a/python/paddle/trainer_config_helpers/tests/configs/check.md5 b/python/paddle/trainer_config_helpers/tests/configs/check.md5 index d1b22b34903df..72dfdad7bdd40 100644 --- a/python/paddle/trainer_config_helpers/tests/configs/check.md5 +++ b/python/paddle/trainer_config_helpers/tests/configs/check.md5 @@ -1,10 +1,11 @@ 86c0815275a9d5eb902e23c6a592f58a img_layers.protostr a5d9259ff1fd7ca23d0ef090052cb1f2 last_first_seq.protostr 9c038249ec8ff719753a746cdb04c026 layer_activations.protostr -5913f87b39cee3b2701fa158270aca26 projections.protostr +34e04043cbb12931c47fa44ec50eeffc projections.protostr 7334ba0a4544f0623231330fc51d390d shared_fc.protostr -8b8b6bb128a7dfcc937be86145f53e2f shared_lstm.protostr +bb8e233b05b8e07f9ed386b7aee4f2c6 shared_lstm.protostr 6b39e34beea8dfb782bee9bd3dea9eb5 simple_rnn_layers.protostr +f98e79e1630d5eb827c300e64836d269 test_bi_grumemory.protostr 0fc1409600f1a3301da994ab9d28b0bf test_cost_layers.protostr 6cd5f28a3416344f20120698470e0a4c test_cost_layers_with_weight.protostr 144bc6d3a509de74115fa623741797ed test_expand_layer.protostr @@ -15,7 +16,7 @@ d350bd91a0dc13e854b1364c3d9339c6 test_lstmemory_layer.protostr 5433ed33d4e7414eaf658f2a55946186 test_maxout.protostr 251a948ba41c1071afcd3d9cf9c233f7 test_ntm_layers.protostr e6ff04e70aea27c7b06d808cc49c9497 test_print_layer.protostr -2a75dd33b640c49a8821c2da6e574577 test_rnn_group.protostr +fded24727338fb8ce44d9951ed8aea08 test_rnn_group.protostr 67d6fde3afb54f389d0ce4ff14726fe1 test_sequence_pooling.protostr f586a548ef4350ba1ed47a81859a64cb unused_layers.protostr -8122477f4f65244580cec09edc590041 util_layers.protostr +f937a5a6e7e8864b4d8cf56b0f7c7f44 util_layers.protostr diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh index 4b1d2d3d41d52..6a31ceabdf36d 100755 --- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh +++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh @@ -9,7 +9,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer last_first_seq test_expand_layer test_ntm_layers test_hsigmoid img_layers util_layers simple_rnn_layers unused_layers test_cost_layers test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight -test_maxout) +test_maxout test_bi_grumemory) for conf in ${configs[*]} diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py b/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py new file mode 100644 index 0000000000000..ab9f7c4948b85 --- /dev/null +++ b/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py @@ -0,0 +1,10 @@ +from paddle.trainer_config_helpers import * + +settings( + batch_size=1000, + learning_rate=1e-4 +) + +din = data_layer(name='data', size=120) + +outputs(bidirectional_gru(input=din, size=40, return_seq=True))