diff --git a/doc/ui/cmd_argument/argument_outline.md b/doc/ui/cmd_argument/argument_outline.md
index 98dadc270dcac..d6cc2c6ed7cc1 100644
--- a/doc/ui/cmd_argument/argument_outline.md
+++ b/doc/ui/cmd_argument/argument_outline.md
@@ -183,7 +183,7 @@ It looks like there are a lot of arguments. However, most of them are for develo
-GPU | gpu_id |
+GPU | gpu_id |
√ | √ | √ | √ |
@@ -207,6 +207,11 @@ It looks like there are a lot of arguments. However, most of them are for develo
√ | √ | √ | √ |
+
+cudnn_conv_workspace_limit_in_mb |
+√ | √ | √ | √ |
+
+
RNN |
beam_size |
diff --git a/doc/ui/cmd_argument/detail_introduction.md b/doc/ui/cmd_argument/detail_introduction.md
index 0d0362d022a72..07608e5edf740 100644
--- a/doc/ui/cmd_argument/detail_introduction.md
+++ b/doc/ui/cmd_argument/detail_introduction.md
@@ -163,6 +163,10 @@
- Choose path to dynamic load NVIDIA CUDA library, for instance, /usr/local/cuda/lib64. [Default]: LD_LIBRARY_PATH
- type: string (default: "", null)
+* `--cudnn_conv_workspace_limit_in_mb`
+  - Specify the maximum workspace limit, in MB, that cuDNN may use for convolution algorithms. The default is 4096 MB (4 GB).
+  - type: int32 (default: 4096)
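+  - For example, passing `--cudnn_conv_workspace_limit_in_mb=1024` on the command line caps the workspace at 1 GB.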
+
## NLP: RNN/LSTM/GRU
* `--rnn_use_batch`
- Whether to use batch method for calculation in simple RecurrentLayer.
diff --git a/paddle/cuda/include/hl_device_functions.cuh b/paddle/cuda/include/hl_device_functions.cuh
index 88d950d6c1713..159c26f443cb1 100755
--- a/paddle/cuda/include/hl_device_functions.cuh
+++ b/paddle/cuda/include/hl_device_functions.cuh
@@ -48,5 +48,24 @@ inline __device__ double paddleAtomicAdd(double* address, double val) {
}
} // namespace paddle
+/**
+ * @brief sum reduction
+ *
+ * @param[in,out] smem     input data; should reside in __shared__ memory.
+ * @param[in]     tid      thread index.
+ * @param[in]     threads  the total number of threads taking part in the
+ *                         reduction, e.g. blockDim.x.
+ *
+ * @return smem[0] holds the sum of all elements in smem.
+ */
+__device__ __forceinline__
+void simpleReduce(real* smem, int tid, int threads) {
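+  // Tree reduction: halve the number of active threads each step.
+  // Assumes threads is a power of two and that every thread in the block
+  // reaches the __syncthreads() below.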
+ for (unsigned int s = threads / 2; s > 0; s >>= 1) {
+ if (tid < s) {
+ smem[tid] += smem[tid + s];
+ }
+ __syncthreads();
+ }
+}
#endif /* HL_DEVICE_FUNCTIONS_CUH_ */
diff --git a/paddle/cuda/include/hl_matrix.h b/paddle/cuda/include/hl_matrix.h
index 17419790471a7..71e8f8e3a60c9 100644
--- a/paddle/cuda/include/hl_matrix.h
+++ b/paddle/cuda/include/hl_matrix.h
@@ -229,4 +229,40 @@ extern void hl_cossim_derivative(real* grad,
int input2_height,
real scale);
+/**
+ * @brief Add a shared (per-channel) bias:
+ *        A_d[i][j] += scale * B_d[j / (dimN / channel)].
+ *
+ * @param[in,out] A_d input and output matrix (M x N).
+ * @param[in]     B_d bias vector (1 x channel).
+ * @param[in] channel width of B.
+ * @param[in] dimM height of A.
+ * @param[in] dimN width of A.
+ * @param[in] scale scalar used for addition.
+ *
+ */
+extern void hl_matrix_add_shared_bias(real* A_d,
+ real* B_d,
+ const int channel,
+ const int dimM,
+ const int dimN,
+ real scale);
+
+/**
+ * @brief Collect a shared (per-channel) bias: for each channel k,
+ *        B_d[k] += scale * (sum of all elements of A_d in channel k).
+ *
+ * @param[in,out] B_d bias vector (1 x channel).
+ * @param[in]     A_d input matrix (M x N).
+ * @param[in] channel width of B.
+ * @param[in] dimM height of A.
+ * @param[in] dimN width of A.
+ * @param[in] scale scalar used for addition.
+ *
+ */
+extern void hl_matrix_collect_shared_bias(real* B_d,
+ real* A_d,
+ const int channel,
+ const int dimM,
+ const int dimN,
+ real scale);
+
#endif /* HL_MATRIX_H_ */
diff --git a/paddle/cuda/include/stub/hl_matrix_stub.h b/paddle/cuda/include/stub/hl_matrix_stub.h
index f1f1020c84d46..e37b1275432ca 100644
--- a/paddle/cuda/include/stub/hl_matrix_stub.h
+++ b/paddle/cuda/include/stub/hl_matrix_stub.h
@@ -101,4 +101,17 @@ inline void hl_cossim_derivative(real* grad,
int input2_height,
real scale) {}
+inline void hl_matrix_add_shared_bias(real* A_d,
+ real* B_d,
+ const int channel,
+ const int dimM,
+ const int dimN,
+ real scale) {}
+
+inline void hl_matrix_collect_shared_bias(real* B_d,
+ real* A_d,
+ const int channel,
+ const int dimM,
+ const int dimN,
+ real scale) {}
#endif // HL_MATRIX_STUB_H_
diff --git a/paddle/cuda/src/hl_cuda_cudnn.cc b/paddle/cuda/src/hl_cuda_cudnn.cc
index b215c0f6e33a1..7810d0d10053d 100644
--- a/paddle/cuda/src/hl_cuda_cudnn.cc
+++ b/paddle/cuda/src/hl_cuda_cudnn.cc
@@ -20,6 +20,11 @@ limitations under the License. */
#include "hl_thread.ph"
#include "hl_dso_loader.h"
#include "paddle/utils/Logging.h"
+#include "paddle/utils/CommandLineParser.h"
+
+P_DEFINE_int32(cudnn_conv_workspace_limit_in_mb, 4096,
+ "Specify cuDNN max workspace limit, in units MB, "
+ "4096MB=4GB by default.");
namespace dynload {
@@ -242,7 +247,7 @@ void hl_conv_workspace(hl_tensor_descriptor input,
CHECK_NOTNULL(conv);
// Specify workspace limit directly
- size_t memoryLimitBytes = 8 * 1024 * 1024;
+ size_t memoryLimitBytes = (1LL << 20) * FLAGS_cudnn_conv_workspace_limit_in_mb;
// cudnn convolution forward configuration
cudnnTensorDescriptor_t fwd_src_desc = GET_TENSOR_DESCRIPTOR(input);
diff --git a/paddle/cuda/src/hl_cuda_matrix.cu b/paddle/cuda/src/hl_cuda_matrix.cu
index 067e68c41e119..3df9f63f9e4b7 100644
--- a/paddle/cuda/src/hl_cuda_matrix.cu
+++ b/paddle/cuda/src/hl_cuda_matrix.cu
@@ -20,6 +20,7 @@ limitations under the License. */
#include "hl_sequence.h"
#include "paddle/utils/Logging.h"
#include "hl_device_functions.cuh"
+#include "hl_gpu_matrix_kernel.cuh"
DEFINE_MATRIX_UNARY_OP(Zero, a = 0);
DEFINE_MATRIX_TERNARY_PARAMETER_OP(_add, TWO_PARAMETER, c = p1*a + p2*b);
@@ -673,3 +674,89 @@ void hl_cossim_derivative(real* grad,
input1_height, input2_height, scale);
CHECK_SYNC("hl_cossim_derivate failed");
}
+
+__global__ void KeMatrixAddSharedBias(real* A,
+ real* B,
+ const int channel,
+ const int M,
+ const int N,
+ real scale) {
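+  // One thread per element of A; column j of A receives the bias of
+  // channel j / (N / channel).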
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ int dim = N / channel;
+ if (index < M * N) {
+ int i = index % N;
+ i = i / dim;
+ A[index] += scale * B[i];
+ }
+}
+
+void hl_matrix_add_shared_bias(real* A_d,
+ real* B_d,
+ const int channel,
+ const int dimM,
+ const int dimN,
+ real scale) {
+ const int blocks = 512;
+ const int grids = DIVUP(dimM * dimN, blocks);
+  KeMatrixAddSharedBias<<<grids, blocks, 0, STREAM_DEFAULT>>>
+ (A_d, B_d, channel, dimM, dimN, scale);
+ CHECK_SYNC("hl_matrix_add_shared_bias failed");
+}
+
+
+template <int blockSize>
+__global__ void KeMatrixCollectSharedBias(real *B,
+ real *A,
+ const int channel,
+ const int M,
+ const int N,
+ const int dim,
+ const int limit,
+ real scale) {
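+  // Two strategies: for small feature maps one thread sums an entire channel;
+  // otherwise one block reduces each channel through shared memory.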
+ if (dim < limit) {
+ int index = blockIdx.x * blockDim.x + threadIdx.x;
+ if (index < channel) {
+ real sum = 0.0;
+ for (int i = 0; i < M; ++i) {
+ for (int j = 0; j < dim; ++j) {
+ sum += A[i * N + index * dim + j];
+ }
+ }
+ B[index] += scale * sum;
+ }
+ } else {
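+    // One block per channel: accumulate blockSize partial sums in shared
+    // memory, then tree-reduce them with simpleReduce().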
+ const int tid = threadIdx.x;
+ const int bid = blockIdx.x;
+ __shared__ real smem[blockSize];
+ real sum = 0.0;
+ for (int j = 0; j < ((dim * M + blockSize - 1) / blockSize); ++j) {
+ int n = j * blockSize + tid;
+ int m = n / dim;
+ int w = n % dim;
+ smem[tid] = (m < M && w < dim) ? A[m * N + bid * dim + w] : 0.0;
+ __syncthreads();
+ simpleReduce(smem, tid, blockSize);
+ sum += smem[0];
+ }
+ if (tid == 0) {
+ B[bid] += scale * sum;
+ }
+ }
+}
+
+void hl_matrix_collect_shared_bias(real* B_d,
+ real* A_d,
+ const int channel,
+ const int dimM,
+ const int dimN,
+ real scale) {
+ const int dim = dimN / channel;
+ const int blocks = 256;
+ const int limit = 64;
+ int grids = (dimM * dim) < limit ? DIVUP(channel, blocks) : channel;
+
+  KeMatrixCollectSharedBias<blocks>
+ <<< grids, blocks, 0, STREAM_DEFAULT>>>
+ (B_d, A_d, channel, dimM, dimN, dim, limit, scale);
+ CHECK_SYNC("hl_matrix_collect_shared_bias failed");
+}
diff --git a/paddle/cuda/src/hl_cuda_sparse.cuh b/paddle/cuda/src/hl_cuda_sparse.cuh
index c3b98f4ebc38d..9cf2d5a843343 100644
--- a/paddle/cuda/src/hl_cuda_sparse.cuh
+++ b/paddle/cuda/src/hl_cuda_sparse.cuh
@@ -908,24 +908,6 @@ int findIndex(int* indice, int num, int index) {
return (end - 1);
}
-/**
- * @brief sum reduction
- *
- * @param[in,out] smem input data, better to use __shared__ memory.
- * @param[in] tid local thread index.
- * @param[in] blockDimX the size of blockDim.x.
- *
- * note: return smem[0]: the sum of each elements of smem.
- */
-__device__ __forceinline__
-void reduce(real* smem, int tid, int blockDimX) {
- for (unsigned int s = blockDimX / 2; s > 0; s >>= 1) {
- if (tid < s) {
- smem[tid] += smem[tid + s];
- }
- __syncthreads();
- }
-}
/**
* @brief sum columns of csr sparse matrix (csr_val), then add to a_val.
diff --git a/paddle/gserver/layers/ConcatenateLayer.cpp b/paddle/gserver/layers/ConcatenateLayer.cpp
index 52a7cb6f777c3..bb6709b8df330 100644
--- a/paddle/gserver/layers/ConcatenateLayer.cpp
+++ b/paddle/gserver/layers/ConcatenateLayer.cpp
@@ -97,7 +97,8 @@ void ConcatenateLayer::backward(const UpdateCallback& callback) {
*/
class ConcatenateLayer2 : public Layer {
public:
- explicit ConcatenateLayer2(const LayerConfig& config) : Layer(config) {}
+ explicit ConcatenateLayer2(const LayerConfig& config) :
+ Layer(config) {}
~ConcatenateLayer2() {}
@@ -110,6 +111,8 @@ class ConcatenateLayer2 : public Layer {
  std::vector<std::unique_ptr<Projection>> projections_;
  std::vector<Argument> projOutput_;
  std::vector<std::pair<size_t, size_t>> projCol_;
+ bool sharedBias_;
+  std::unique_ptr<Weight> biases_;
};
REGISTER_LAYER(concat2, ConcatenateLayer2);
@@ -119,7 +122,6 @@ bool ConcatenateLayer2::init(const LayerMap& layerMap,
/* Initialize the basic parent class */
if (!Layer::init(layerMap, parameterMap)) return false;
- CHECK(!biasParameter_);
CHECK_EQ(inputLayers_.size(), parameters_.size());
projections_.reserve(inputLayers_.size());
projCol_.reserve(inputLayers_.size());
@@ -137,6 +139,13 @@ bool ConcatenateLayer2::init(const LayerMap& layerMap,
}
CHECK_EQ(getSize(), endCol);
+ /* initialize biases_ */
+ if (biasParameter_.get() != NULL) {
+ sharedBias_ = config_.shared_biases();
+ size_t psize = config_.bias_size();
+    biases_ = std::unique_ptr<Weight>(new Weight(1, psize, biasParameter_));
+ }
+
return true;
}
@@ -154,8 +163,17 @@ void ConcatenateLayer2::forward(PassType passType) {
projOutput_[i].grad = output_.grad->subColMatrix(startCol, endCol);
}
- for (size_t i = 0; i != inputLayers_.size(); ++i) {
- projections_[i]->forward(&getInput(i), &projOutput_[i], passType);
+ {
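+    // Group the projection forwards in one AsyncGpuBlock so their GPU
+    // kernels are issued without intermediate synchronization.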
+ AsyncGpuBlock block;
+ for (size_t i = 0; i != inputLayers_.size(); ++i) {
+ projections_[i]->forward(&getInput(i), &projOutput_[i], passType);
+ }
+ }
+
+ /* add the bias-vector */
+ if (biases_) {
+ REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
+ output_.value->addBias(*(biases_->getW()), 1, sharedBias_);
}
/* activation */ {
@@ -170,6 +188,13 @@ void ConcatenateLayer2::backward(const UpdateCallback& callback) {
backwardActivation();
}
+ AsyncGpuBlock block;
+ if (biases_ && biases_->getWGrad()) {
+ REGISTER_TIMER_INFO("Concat2BpBiasTimer", getName().c_str());
+ biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBias_);
+ biases_->getParameterPtr()->incUpdate(callback);
+ }
+
for (size_t i = 0; i != inputLayers_.size(); ++i) {
if (projections_[i]) {
projections_[i]->backward(callback);
diff --git a/paddle/gserver/layers/ConvBaseLayer.cpp b/paddle/gserver/layers/ConvBaseLayer.cpp
index 9ed9572139dc8..040510b7ad211 100644
--- a/paddle/gserver/layers/ConvBaseLayer.cpp
+++ b/paddle/gserver/layers/ConvBaseLayer.cpp
@@ -35,25 +35,12 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
filterSizeY_.push_back(conf.filter_size_y());
filterPixels_.push_back(filterSize_.back() * filterSizeY_.back());
channels_.push_back(conf.channels());
- imgSize_.push_back(conf.img_size());
- imgPixels_.push_back(imgSize_.back() * imgSize_.back());
+ imgSizeH_.push_back(conf.img_size());
+ imgSizeW_.push_back(conf.img_size());
groups_.push_back(conf.groups());
filterChannels_.push_back(conf.filter_channels());
- outputX_.push_back(conf.output_x());
- outputs_.push_back(outputX_.back() * outputX_.back());
- }
-
- /* initialize the weightList */
- CHECK(inputLayers_.size() == parameters_.size());
- for (size_t i = 0; i < inputLayers_.size(); i++) {
- size_t height, width;
- height = filterPixels_[i] * filterChannels_[i];
- width = numFilters_;
-
- // create a new weight
- CHECK_EQ(parameters_[i]->getSize(), width * height);
- Weight* w = new Weight(height, width, parameters_[i]);
- weights_.emplace_back(w);
+ outputH_.push_back(conf.output_x());
+ outputW_.push_back(conf.output_x());
}
/* initialize the biases_ */
@@ -74,4 +61,34 @@ bool ConvBaseLayer::init(const LayerMap& layerMap,
return true;
}
+size_t ConvBaseLayer::calOutputSize() {
+ auto clearAndReserve = [this](IntV* vec) {
+ vec->clear();
+ vec->reserve(this->inputLayers_.size());
+ };
+ clearAndReserve(&imgSizeH_);
+ clearAndReserve(&imgSizeW_);
+ clearAndReserve(&outputH_);
+ clearAndReserve(&outputW_);
+ size_t layerSize = 0;
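+  // Prefer the frame size reported by each input layer; fall back to the
+  // configured img_size when the input does not set one.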
+ for (size_t i = 0; i < inputLayers_.size(); i++) {
+ imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight());
+ imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
+ if (imgSizeH_[i] == 0)
+ imgSizeH_[i] = config_.inputs(i).conv_conf().img_size();
+ if (imgSizeW_[i] == 0)
+ imgSizeW_[i] = config_.inputs(i).conv_conf().img_size();
+ outputH_.push_back(
+ outputSize(imgSizeH_[i], filterSizeY_[i], paddingY_[i], strideY_[i]));
+ outputW_.push_back(
+ outputSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i]));
+ CHECK_EQ(outputH_[i], outputH_[0]);
+ CHECK_EQ(outputW_[i], outputW_[0]);
+ }
+ getOutput().setFrameHeight(outputH_[0]);
+ getOutput().setFrameWidth(outputW_[0]);
+ layerSize = outputH_[0] * outputW_[0] * size_t(numFilters_);
+ return layerSize;
+}
+
} // namespace paddle
diff --git a/paddle/gserver/layers/ConvBaseLayer.h b/paddle/gserver/layers/ConvBaseLayer.h
index eaeaebf43be25..316514acf1a0d 100644
--- a/paddle/gserver/layers/ConvBaseLayer.h
+++ b/paddle/gserver/layers/ConvBaseLayer.h
@@ -43,19 +43,18 @@ class ConvBaseLayer : public Layer {
IntV filterSizeY_;
/// The spatial dimensions of the convolution input.
IntV channels_;
- /// The spatial dimensions of input feature map.
- IntV imgSize_;
- /// The total pixel size of input feature map.
- /// imgPixels_ = imgSizeX_ * imgSizeY_.
- IntV imgPixels_;
+  /// The height of the input feature map.
+  IntV imgSizeH_;
+  /// The width of the input feature map.
+  IntV imgSizeW_;
/// filterPixels_ = filterSizeX_ * filterSizeY_.
IntV filterPixels_;
/// filterChannels_ = channels_/groups_.
IntV filterChannels_;
- /// The spatial dimensions of output feature map.
- IntV outputX_;
- /// The spatial dimensions of output feature map.
- IntV outputs_;
+  /// The height of the output feature map.
+  IntV outputH_;
+  /// The width of the output feature map.
+  IntV outputW_;
/// Group size, refer to grouped convolution in
/// Alex Krizhevsky's paper: when group=2, the first half of the
/// filters are only connected to the first half of the input channels,
@@ -80,6 +79,13 @@ class ConvBaseLayer : public Layer {
virtual bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
+  /**
+   * Set imgSizeH_ and imgSizeW_ according to the preceding input layers,
+   * then calculate outputH_ and outputW_ and store them in the output
+   * argument.
+   */
+ virtual size_t calOutputSize();
+
Weight& getWeight(int idx) { return *weights_[idx]; }
/**
diff --git a/paddle/gserver/layers/ConvProjection.cpp b/paddle/gserver/layers/ConvProjection.cpp
new file mode 100644
index 0000000000000..d1ce53fe26351
--- /dev/null
+++ b/paddle/gserver/layers/ConvProjection.cpp
@@ -0,0 +1,210 @@
+/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+
+#include "paddle/utils/Stat.h"
+#include "ConvProjection.h"
+
+namespace paddle {
+
+REGISTER_PROJECTION(conv, ConvProjection);
+
+ThreadLocalD<std::vector<MemoryHandle*>> ConvProjection::convMem_;
+
+ConvProjection::ConvProjection(const ProjectionConfig& config,
+ ParameterPtr parameter, bool useGpu)
+ : Projection(config, parameter, useGpu) {
+
+ CHECK(useGpu); // only support GPU
+ getConvParams();
+ initCudnn();
+
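+  // The weight matrix has (channels_ / groups_ * filterH_ * filterW_) rows
+  // and numFilters_ columns; weightOffset_ is the size of one group's slice.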
+ size_t height = filterH_ * filterW_ * channels_ / groups_;
+ size_t width = numFilters_;
+ weight_.reset(new Weight(height, width, parameter));
+ weightOffset_ = height * width / groups_;
+}
+
+void ConvProjection::getConvParams() {
+ const ConvConfig &conf = config_.conv_conf();
+ paddingH_ = conf.padding_y();
+ paddingW_ = conf.padding();
+
+ strideH_ = conf.stride_y();
+ strideW_ = conf.stride();
+
+ filterH_ = conf.filter_size_y();
+ filterW_ = conf.filter_size();
+
+ configImgH_ = conf.img_size();
+ configImgW_ = conf.img_size();
+
+ channels_ = conf.channels();
+ numFilters_ = config_.num_filters();
+
+ groups_ = conf.groups();
+ CHECK_EQ(channels_ % groups_, 0);
+ CHECK_EQ(numFilters_ % groups_, 0);
+}
+
+void ConvProjection::initCudnn() {
+ hl_create_filter_descriptor(&filterDesc_, channels_, numFilters_,
+ filterH_, filterW_);
+ hl_create_tensor_descriptor(&inputDesc_);
+ hl_create_tensor_descriptor(&outputDesc_);
+ hl_create_convolution_descriptor(&convDesc_, inputDesc_, filterDesc_,
+ paddingH_, paddingW_, strideH_, strideW_);
+
+ // initialize all to default algorithms
+ fwdAlgo_ = 0;
+ bwdFilterAlgo_ = 0;
+ bwdDataAlgo_ = 0;
+ fwdLimitBytes_ = 0;
+ bwdDataLimitBytes_ = 0;
+ bwdFilterLimitBytes_ = 0;
+ workSpaceInBytes_ = 0;
+
+ batchNum_ = 0;
+ isSelectAlgo_ = false;
+}
+
+void ConvProjection::reshapeTensorDesc(int batchSize) {
+ hl_tensor_reshape(inputDesc_, batchSize, channels_, imageH_, imageW_,
+ channels_ * imageH_ * imageW_, imageH_ * imageW_,
+ imageW_, 1);
+ hl_reset_convolution_descriptor(convDesc_, inputDesc_, filterDesc_,
+ paddingH_, paddingW_, strideH_, strideW_);
+
+  // The stride between two consecutive samples in the output of ConvProjection
+  // may not equal numFilters_ * outputH_ * outputW_. For example, when a
+  // ConcatenateLayer2 contains two ConvProjections, the stride is the output
+  // size of ConcatenateLayer2, so nStride is computed differently from
+  // CudnnConvLayer. Strictly speaking, nStride = out_->value->getStride() is
+  // always the correct value.
+ size_t nStride = numFilters_ * outputH_ * outputW_;
+ if (out_->value->isContiguous()) {
+ CHECK_EQ(nStride, out_->value->getWidth());
+ } else {
+ nStride = out_->value->getStride();
+ }
+
+ hl_tensor_reshape(outputDesc_, batchSize, numFilters_, outputH_, outputW_,
+ nStride, outputH_ * outputW_, outputW_, 1);
+}
+
+void ConvProjection::reshape(int batchSize) {
+ size_t width = calOutputSize();
+ CHECK_EQ(width, out_->value->getWidth());
+
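+  // Re-create the tensor descriptors and re-select the cuDNN algorithms only
+  // when the batch size changes.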
+ isSelectAlgo_ = (batchSize == batchNum_);
+ batchNum_ = batchSize;
+
+ if (!isSelectAlgo_) {
+ reshapeTensorDesc(batchSize);
+ hl_conv_workspace(inputDesc_, outputDesc_, filterDesc_,
+ convDesc_, &fwdAlgo_, &fwdLimitBytes_,
+ &bwdDataAlgo_, &bwdDataLimitBytes_,
+ &bwdFilterAlgo_, &bwdFilterLimitBytes_);
+
+ size_t maxWorkSpace = 0;
+ maxWorkSpace = std::max(fwdLimitBytes_, bwdDataLimitBytes_);
+ maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_);
+ workSpaceInBytes_ = maxWorkSpace;
+
+
+ VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_
+ << " / " << bwdDataAlgo_
+ << " / " << bwdFilterAlgo_;
+ }
+
+ isSelectAlgo_ = true;
+}
+
+void ConvProjection::forward() {
+ int batchSize = in_->value->getHeight();
+ reshape(batchSize);
+
+ void* workSpace = NULL;
+ if (workSpaceInBytes_ > 0) {
+ workSpace = getSpaceBytes(workSpaceInBytes_);
+ }
+
+ for (int g = 0; g < groups_; ++g) {
+ REGISTER_TIMER_INFO("CudnnConvFwTimer", getName().c_str());
+
+ real *inputData = in_->value->getData() + g * inputOffset_;
+ real *wgtData = weight_->getW()->getData() + g * weightOffset_;
+ real *outData = out_->value->getData() + g * outputOffset_;
+ hl_convolution_forward(inputDesc_, inputData, outputDesc_,
+ outData, filterDesc_, wgtData,
+ convDesc_, workSpace,
+ fwdLimitBytes_, fwdAlgo_);
+ }
+}
+
+void ConvProjection::backward(const UpdateCallback& callback) {
+ REGISTER_TIMER_INFO("CudnnConvBpTimer", getName().c_str());
+
+ void* workSpace = NULL;
+ if (workSpaceInBytes_ > 0) {
+ workSpace = getSpaceBytes(workSpaceInBytes_);
+ }
+
+ for (int g = 0; g < groups_; ++g) {
+ real *outGrad = out_->grad->getData() + g * outputOffset_;
+ if (weight_->getWGrad()) {
+ real *inputData = in_->value->getData() + g * inputOffset_;
+ real *weightGrad = weight_->getWGrad()->getData() + g * weightOffset_;
+ hl_convolution_backward_filter(
+ inputDesc_, inputData, outputDesc_, outGrad, filterDesc_,
+ weightGrad, convDesc_, workSpace, bwdFilterLimitBytes_,
+ bwdFilterAlgo_);
+ }
+
+ MatrixPtr preGrad = in_->grad;
+ if (NULL != preGrad) {
+ real *inputGrad = preGrad->getData() + g * inputOffset_;
+ real *wgtData = weight_->getW()->getData() + g* weightOffset_;
+ hl_convolution_backward_data(
+ inputDesc_, inputGrad, outputDesc_, outGrad, filterDesc_,
+ wgtData, convDesc_, workSpace, bwdDataLimitBytes_,
+ bwdDataAlgo_);
+ }
+ }
+
+ weight_->getParameterPtr()->incUpdate(callback);
+}
+
+void* ConvProjection::getSpaceBytes(size_t size) {
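+  // convMem_ caches one workspace buffer per device (per thread); it is grown
+  // whenever a larger size is requested.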
+  std::vector<MemoryHandle*>& convMem = *convMem_;
+ if (convMem.empty()) {
+ int numDevices = hl_get_device_count();
+ convMem.resize(numDevices);
+ }
+
+ int devId = hl_get_device();
+ MemoryHandle** localMem = &(convMem[devId]);
+ if (NULL == *localMem || size > (*localMem)->getAllocSize()) {
+ *localMem = new GpuMemoryHandle(size);
+ }
+ return (*localMem)->getBuf();
+}
+
+ConvProjection::~ConvProjection() {
+ hl_destroy_tensor_descriptor(inputDesc_);
+ hl_destroy_tensor_descriptor(outputDesc_);
+ hl_destroy_filter_descriptor(filterDesc_);
+ hl_destroy_convolution_descriptor(convDesc_);
+}
+
+} // namespace paddle
diff --git a/paddle/gserver/layers/ConvProjection.h b/paddle/gserver/layers/ConvProjection.h
new file mode 100644
index 0000000000000..41a100ac3c50f
--- /dev/null
+++ b/paddle/gserver/layers/ConvProjection.h
@@ -0,0 +1,125 @@
+/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+
+#pragma once
+
+#include "Projection.h"
+
+namespace paddle {
+
+/**
+ * @brief A convolution projection that performs the same calculation as
+ *        CudnnConvLayer.
+ */
+class ConvProjection : public Projection {
+public:
+ /**
+ * Constructor.
+ */
+ ConvProjection(const ProjectionConfig& config, ParameterPtr parameter,
+ bool useGpu);
+
+ ~ConvProjection();
+
+ virtual void forward();
+ virtual void backward(const UpdateCallback& callback);
+
+protected:
+ void getConvParams();
+ void initCudnn();
+
+ void reshapeTensorDesc(int batchSize);
+ void reshape(int batchSize);
+
+ int outputSize(int imageSize, int filterSize, int padding, int stride) {
+ return (imageSize - filterSize + 2 * padding) / stride + 1;
+ }
+
+ size_t calOutputSize() {
+ imageH_ = in_->getFrameHeight();
+ imageW_ = in_->getFrameWidth();
+ if (imageH_ == 0) imageH_ = configImgH_;
+ if (imageW_ == 0) imageW_ = configImgW_;
+ outputH_ = outputSize(imageH_, filterH_, paddingH_, strideH_);
+ outputW_ = outputSize(imageW_, filterW_, paddingW_, strideW_);
+
+    const_cast<Argument*>(out_)->setFrameHeight(outputH_);
+    const_cast<Argument*>(out_)->setFrameWidth(outputW_);
+
+ inputOffset_ = (channels_ / groups_) * imageH_ * imageW_;
+ outputOffset_ = (numFilters_ / groups_) * outputH_ * outputW_;
+ return outputH_ * outputW_ * numFilters_;
+ }
+
+ static void* getSpaceBytes(size_t size);
+
+  /// imageH_ and imageW_ are calculated from the input layer.
+  int imageH_, imageW_;
+  /// configImgH_ and configImgW_ are obtained from the config.
+  int configImgH_, configImgW_;
+ int outputH_, outputW_;
+ int channels_, numFilters_;
+ int paddingH_, paddingW_;
+ int strideH_, strideW_;
+ int filterH_, filterW_;
+ /// One group offset of input data.
+ int inputOffset_;
+ /// One group offset of output data.
+ int outputOffset_;
+ /// One group offset of weight.
+ int weightOffset_;
+ int groups_;
+
+ /// Cudnn tensor descriptor for input.
+ hl_tensor_descriptor inputDesc_;
+ /// Cudnn tensor descriptor for output.
+ hl_tensor_descriptor outputDesc_;
+ /// Cudnn tensor descriptor for filter.
+ hl_filter_descriptor filterDesc_;
+ /// Cudnn tensor descriptor for a convolution operation.
+ hl_convolution_descriptor convDesc_;
+
+  /// The algorithm for forward convolution, obtained from the cuDNN API's
+  /// search for the best-suited algorithm.
+ int fwdAlgo_;
+ /// Record the algorithm for computing convolution gradient with respect to
+ /// filter coefficients.
+ int bwdFilterAlgo_;
+ /// Record the algorithm for computing convolution gradient with respect to
+ /// the output.
+ int bwdDataAlgo_;
+ /// Amount of GPU memory needed as workspace to be able to execute a
+ /// forward convolution with the specified algo.
+ size_t fwdLimitBytes_;
+  /// Amount of GPU memory needed as workspace to be able to execute a
+  /// backwardData with the specified algo.
+  size_t bwdDataLimitBytes_;
+  /// Amount of GPU memory needed as workspace to be able to execute a
+  /// backwardFilter with the specified algo.
+  size_t bwdFilterLimitBytes_;
+ /// Size of total work space.
+ size_t workSpaceInBytes_;
+
+  /// Whether the conv algorithms have already been selected, i.e. whether the
+  /// cuDNN algorithm search can be skipped for the current batch size.
+  bool isSelectAlgo_;
+  /// batchNum_ records the batch size. If the batch size changes, the
+  /// algorithm selection is run again.
+  int batchNum_;
+ bool bias_;
+
+ std::unique_ptr weight_;
+  static ThreadLocalD<std::vector<MemoryHandle*>> convMem_;
+};
+
+} // namespace paddle
diff --git a/paddle/gserver/layers/CudnnConvLayer.cpp b/paddle/gserver/layers/CudnnConvLayer.cpp
index 0f932f960f6ba..e77216f17c3be 100644
--- a/paddle/gserver/layers/CudnnConvLayer.cpp
+++ b/paddle/gserver/layers/CudnnConvLayer.cpp
@@ -22,215 +22,64 @@ REGISTER_LAYER(cudnn_conv, CudnnConvLayer);
bool CudnnConvLayer::init(const LayerMap &layerMap,
const ParameterMap ¶meterMap) {
- ConvBaseLayer::init(layerMap, parameterMap);
+ if (!ConvBaseLayer::init(layerMap, parameterMap)) return false;
CHECK(useGpu_) << "CudnnConvLayer only support gpu";
- maxGroups_ = 0;
- for (size_t i = 0; i < inputLayers_.size(); i++) {
- CHECK_EQ(channels_[i] % groups_[i], 0);
- CHECK_EQ(numFilters_ % groups_[i], 0);
-
- hl_filter_descriptor filter;
- hl_create_filter_descriptor(&filter, channels_[i] / groups_[i],
- numFilters_ / groups_[i], filterSizeY_[i],
- filterSize_[i]);
- filterDesc_.push_back(filter);
-
- hl_tensor_descriptor input;
- hl_create_tensor_descriptor(&input);
- inputDesc_.push_back(input);
-
- hl_tensor_descriptor output;
- int outputX =
- outputSize(imgSize_[i], filterSize_[i], padding_[i], stride_[i]);
- CHECK_EQ(outputX, outputX_[i]);
- hl_create_tensor_descriptor(&output);
- outputDesc_.push_back(output);
+ CHECK_EQ(inputLayers_.size(), parameters_.size());
+ projections_.reserve(inputLayers_.size());
+ projConf_.reserve(inputLayers_.size());
- hl_convolution_descriptor conv;
- hl_create_convolution_descriptor(&conv, input, filter, paddingY_[i],
- padding_[i], strideY_[i], stride_[i]);
- convDesc_.push_back(conv);
-
- weightOffset_.push_back((numFilters_ / groups_[i]) *
- (channels_[i] / groups_[i]) * filterPixels_[i]);
- inputOffset_.push_back((channels_[i] / groups_[i]) * imgSize_[i] *
- imgSize_[i]);
- outputOffset_.push_back((numFilters_ / groups_[i]) * outputX_[i] *
- outputX_[i]);
-
- // initialize all to default algorithms
- fwdAlgo_.push_back(0);
- bwdFilterAlgo_.push_back(0);
- bwdDataAlgo_.push_back(0);
- fwdLimitBytes_.push_back(0);
- bwdFilterLimitBytes_.push_back(0);
- bwdDataLimitBytes_.push_back(0);
-
- // cudnn streams per group equal to 1
- if (groups_[i] > maxGroups_) {
- maxGroups_ = groups_[i];
- }
- }
-
- workSpaceInBytes_ = 0;
- workSpaceData_ = NULL;
- for (int i = 0; i < maxGroups_; ++i) {
- workSpace_.push_back(NULL);
+ numFilters_ = config_.num_filters();
+ CHECK(config_.shared_biases());
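+  // Wrap each input with a "conv" projection; the projection owns the cuDNN
+  // descriptors and performs the algorithm selection.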
+ for (size_t i = 0; i < inputLayers_.size(); i++) {
+ ProjectionConfig* conf = new ProjectionConfig();
+ conf->set_type("conv");
+ conf->set_num_filters(numFilters_);
+ conf->set_allocated_conv_conf(
+ config_.mutable_inputs(i)->mutable_conv_conf());
+ conf->set_input_size(getPrev(i)->getSize());
+ conf->set_output_size(getSize());
+ projConf_.emplace_back(conf);
+ projections_.emplace_back(Projection::create(*projConf_[i],
+ parameters_[i], useGpu_));
}
if (biases_.get() && sharedBiases_) {
hl_create_tensor_descriptor(&biasDesc_);
+ hl_create_tensor_descriptor(&outputDesc_);
hl_tensor_reshape(biasDesc_, 1, numFilters_ / groups_[0], 1, 1);
biasOffset_ = numFilters_ / groups_[0];
}
- batchNum_ = 0;
- isSelectAlgo_ = false;
return true;
}
-void CudnnConvLayer::allocConvWorkSpace(size_t maxWorkSpace) {
- size_t totalWorkSpace = maxWorkSpace * maxGroups_;
-
- if (totalWorkSpace > workSpaceInBytes_) {
- if (workSpaceInBytes_ != 0) {
- hl_free_mem_device(workSpaceData_);
- }
- // total amount of storage needed over all groups
- workSpaceData_ = hl_malloc_device(totalWorkSpace);
-
- // update work space address for each group
- for (int i = 0; i < maxGroups_; ++i) {
-      workSpace_[i] = reinterpret_cast<char *>(workSpaceData_)
- + i * maxWorkSpace;
- }
- workSpaceInBytes_ = totalWorkSpace;
- }
-}
-
-void CudnnConvLayer::reshape(int batchSize) {
- CHECK_NE(inputLayers_.size(), 0UL);
- imageH_ = inputLayers_[0]->getOutput().getFrameHeight();
- imageW_ = inputLayers_[0]->getOutput().getFrameWidth();
- if (imageH_ == 0) imageH_ = imgSize_[0];
- if (imageW_ == 0) imageW_ = imgSize_[0];
-
- for (size_t i = 1; i < inputLayers_.size(); i++) {
- int imageH = inputLayers_[i]->getOutput().getFrameHeight();
- int imageW = inputLayers_[i]->getOutput().getFrameWidth();
- if (imageH) {
- CHECK_EQ(imageH_, imageH) << "Inputs must have same height.";
- }
- if (imageW) {
- CHECK_EQ(imageW_, imageW) << "Inputs must have same width.";
- }
- }
-
- outputH_ = outputSize(imageH_, filterSizeY_[0], paddingY_[0], strideY_[0]);
- outputW_ = outputSize(imageW_, filterSize_[0], padding_[0], stride_[0]);
- // check outputH & outputW
- getOutput().setFrameHeight(outputH_);
- getOutput().setFrameWidth(outputW_);
-
- // if the batchSize remains the same, set isSelectAlgo_ true.
- // Otherwise, set isSelectAlgo_ false and select algo again.
- isSelectAlgo_ = (batchSize == batchNum_);
- batchNum_ = batchSize;
-
- size_t maxWorkSpace = 0;
- for (size_t i = 0; i < inputLayers_.size(); i++) {
- CHECK_EQ(inputLayers_[i]->getOutput().value->getWidth(),
- (size_t)(channels_[i] * imageH_ * imageW_));
-
- hl_tensor_reshape(inputDesc_[i], batchSize, channels_[i] / groups_[i],
- imageH_, imageW_, channels_[i] * imageH_ * imageW_,
- imageH_ * imageW_, imageW_, 1);
-
- hl_tensor_reshape(outputDesc_[i], batchSize, numFilters_ / groups_[i],
- outputH_, outputW_, numFilters_ * outputH_ * outputW_,
- outputH_ * outputW_, outputW_, 1);
-
- hl_reset_convolution_descriptor(convDesc_[i], inputDesc_[i],
- filterDesc_[i], paddingY_[i],
- padding_[i], strideY_[i], stride_[i]);
-
- inputOffset_[i] = (channels_[i] / groups_[i]) * imageH_ * imageW_;
- outputOffset_[i] = (numFilters_ / groups_[i]) * outputH_ * outputW_;
-
- if (!isSelectAlgo_) {
- hl_conv_workspace(inputDesc_[i], outputDesc_[i], filterDesc_[i],
- convDesc_[i], &fwdAlgo_[i], &fwdLimitBytes_[i],
- &bwdDataAlgo_[i], &bwdDataLimitBytes_[i],
- &bwdFilterAlgo_[i], &bwdFilterLimitBytes_[i]);
-
- maxWorkSpace = std::max(fwdLimitBytes_[i], bwdDataLimitBytes_[i]);
- maxWorkSpace = std::max(maxWorkSpace, bwdFilterLimitBytes_[i]);
-
- VLOG(3) << getName() << " Fwd / BwdData / BwdFilter algo: " << fwdAlgo_[i]
- << " / " << bwdDataAlgo_[i]
- << " / " << bwdFilterAlgo_[i];
- }
- }
-
- if (!isSelectAlgo_) {
- allocConvWorkSpace(maxWorkSpace);
- }
-
- isSelectAlgo_ = true;
-}
-
void CudnnConvLayer::forward(PassType passType) {
Layer::forward(passType);
- int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
- reshape(batchSize);
- resetOutput(batchSize, outputH_ * outputW_ * numFilters_);
+
+ int batchSize = getInput(0).getBatchSize();
+ resetOutput(batchSize, calOutputSize());
for (size_t i = 0; i != inputLayers_.size(); ++i) {
- REGISTER_TIMER_INFO("CudnnConvFwTimer", getName().c_str());
- for (int g = 0; g < groups_[i]; ++g) {
- real *inputData = getInputValue(i)->getData() + inputOffset_[i] * g;
- real *wgtData = weights_[i]->getW()->getData() + weightOffset_[i] * g;
- real *outData = getOutputValue()->getData() + outputOffset_[i] * g;
- hl_convolution_forward(inputDesc_[i], inputData, outputDesc_[i],
- outData, filterDesc_[i], wgtData,
- convDesc_[i], workSpace_[g],
- fwdLimitBytes_[i], fwdAlgo_[i]);
- }
+ projections_[i]->forward(&getInput(i), &getOutput(), passType);
}
if (biases_) {
REGISTER_TIMER_INFO("CudnnConvBiasTimer", getName().c_str());
- addBiases();
- }
-
- forwardActivation();
-}
-
-void CudnnConvLayer::addBiases() {
- if (sharedBiases_) {
+ int batchSize = inputLayers_[0]->getOutputValue()->getHeight();
+ hl_tensor_reshape(outputDesc_, batchSize, numFilters_ / groups_[0],
+ outputH_[0], outputW_[0], numFilters_ * outputH_[0] * outputW_[0],
+ outputH_[0] * outputW_[0], outputW_[0], 1);
+ outputOffset_ = getOutputValue()->getWidth() / groups_[0];
for (int g = 0; g < groups_[0]; ++g) {
real *biasData = biases_->getW()->getData() + biasOffset_ * g;
- real *outData = getOutputValue()->getData() + outputOffset_[0] * g;
+ real *outData = getOutputValue()->getData() + outputOffset_ * g;
hl_convolution_forward_add_bias(biasDesc_, biasData,
- outputDesc_[0], outData);
+ outputDesc_, outData);
}
- } else {
- LOG(FATAL) << "Not supported";
}
-}
-void CudnnConvLayer::bpropBiases() {
- if (sharedBiases_) {
- for (int g = 0; g < groups_[0]; ++g) {
- real *biasGrad = biases_->getWGrad()->getData() + biasOffset_ * g;
- real *outGrad = getOutputGrad()->getData() + outputOffset_[0] * g;
- hl_convolution_backward_bias(biasDesc_, biasGrad,
- outputDesc_[0], outGrad);
- }
- } else {
- LOG(FATAL) << "Not supported";
- }
+ forwardActivation();
}
void CudnnConvLayer::backward(const UpdateCallback &callback) {
@@ -238,52 +87,23 @@ void CudnnConvLayer::backward(const UpdateCallback &callback) {
if (biases_ && biases_->getWGrad()) {
REGISTER_TIMER_INFO("CudnnConvBpBiasTimer", getName().c_str());
- bpropBiases();
+ for (int g = 0; g < groups_[0]; ++g) {
+ real *biasGrad = biases_->getWGrad()->getData() + biasOffset_ * g;
+ real *outGrad = getOutputGrad()->getData() + outputOffset_ * g;
+ hl_convolution_backward_bias(biasDesc_, biasGrad, outputDesc_, outGrad);
+ }
biases_->getParameterPtr()->incUpdate(callback);
}
for (size_t i = 0; i != inputLayers_.size(); ++i) {
- REGISTER_TIMER_INFO("CudnnConvBpTimer", getName().c_str());
- for (int g = 0; g < groups_[i]; ++g) {
- real *outGrad = getOutputGrad()->getData() + outputOffset_[i] * g;
- if (weights_[i]->getWGrad()) {
- real *inputData = getInputValue(i)->getData() + inputOffset_[i] * g;
- real *weightGrad =
- weights_[i]->getWGrad()->getData() + weightOffset_[i] * g;
- hl_convolution_backward_filter(
- inputDesc_[i], inputData, outputDesc_[i], outGrad, filterDesc_[i],
- weightGrad, convDesc_[i], workSpace_[g], bwdFilterLimitBytes_[i],
- bwdFilterAlgo_[i]);
- }
-
- MatrixPtr preGrad = getInputGrad(i);
- if (NULL != preGrad) {
- real *inputGrad = preGrad->getData() + inputOffset_[i] * g;
- real *wgtData = weights_[i]->getW()->getData() + weightOffset_[i] * g;
- hl_convolution_backward_data(
- inputDesc_[i], inputGrad, outputDesc_[i], outGrad, filterDesc_[i],
- wgtData, convDesc_[i], workSpace_[g], bwdDataLimitBytes_[i],
- bwdDataAlgo_[i]);
- }
- }
- weights_[i]->getParameterPtr()->incUpdate(callback);
+ projections_[i]->backward(callback);
}
}
CudnnConvLayer::~CudnnConvLayer() {
- if (biasDesc_) {
+ if (biases_) {
hl_destroy_tensor_descriptor(biasDesc_);
- }
-
- for (size_t i = 0; i < inputDesc_.size(); i++) {
- hl_destroy_tensor_descriptor(inputDesc_[i]);
- hl_destroy_tensor_descriptor(outputDesc_[i]);
- hl_destroy_filter_descriptor(filterDesc_[i]);
- hl_destroy_convolution_descriptor(convDesc_[i]);
- }
- if (workSpaceInBytes_ != 0) {
- hl_free_mem_device(workSpaceData_);
- workSpaceInBytes_ = 0;
+ hl_destroy_tensor_descriptor(outputDesc_);
}
}
diff --git a/paddle/gserver/layers/CudnnConvLayer.h b/paddle/gserver/layers/CudnnConvLayer.h
index a6dadba10daa4..6390d96315cc4 100644
--- a/paddle/gserver/layers/CudnnConvLayer.h
+++ b/paddle/gserver/layers/CudnnConvLayer.h
@@ -17,12 +17,13 @@ limitations under the License. */
#include "ConvBaseLayer.h"
#include "paddle/math/Matrix.h"
+#include "Projection.h"
#include <vector>
namespace paddle {
/**
- * @brief A subclass of ConvBaseLayer by cuDNN implementation. It only
+ * @brief A 2-dimensional convolution layer implemented with cuDNN. It only
* supports GPU mode. We automatic select CudnnConvLayer for GPU
* mode and ExpandConvLayer for CPU mode if you set type of "conv".
* User also can specfiy type of "exconv" or "cudnn_conv" for
@@ -31,81 +32,21 @@ namespace paddle {
* The config file api is img_conv_layer.
*/
class CudnnConvLayer : public ConvBaseLayer {
-private:
- /// resize Cudnn workspace size
- void allocConvWorkSpace(size_t maxWorkSpace);
-
protected:
- int imageH_, imageW_, outputH_, outputW_;
- /// Cudnn tensor descriptor for bias.
+  std::vector<std::unique_ptr<ProjectionConfig>> projConf_;
+  std::vector<std::unique_ptr<Projection>> projections_;
+
hl_tensor_descriptor biasDesc_;
- /// Cudnn tensor descriptor for input.
-  std::vector<hl_tensor_descriptor> inputDesc_;
-  /// Cudnn tensor descriptor for output.
-  std::vector<hl_tensor_descriptor> outputDesc_;
-  /// Cudnn tensor descriptor for filter.
-  std::vector<hl_filter_descriptor> filterDesc_;
-  /// Cudnn tensor descriptor for a convolution operation.
-  std::vector<hl_convolution_descriptor> convDesc_;
- /// One sample offset of input data.
- IntV inputOffset_;
- /// One sample offset of output data.
- IntV outputOffset_;
- /// One group offset of weight.
- IntV weightOffset_;
- /// One group offset of bias.
+ hl_tensor_descriptor outputDesc_;
int biasOffset_;
-
- /// Save the algorithm for forward convolution, which is obtained by cudnn
- /// api to search the best suited algorithm.
-  std::vector<int> fwdAlgo_;
-  /// Save the algorithm for computing convolution gradient with respect to
-  /// filter coefficients.
-  std::vector<int> bwdFilterAlgo_;
-  /// Save the algorithm for computing convolution gradient with respect to
-  /// the output.
-  std::vector<int> bwdDataAlgo_;
-  /// Amount of GPU memory needed as workspace to be able to execute a
-  /// forward convolution with the specified algo.
-  std::vector<size_t> fwdLimitBytes_;
-  /// Amount of GPU memory needed as workspace to be able to execute a
-  /// backwardFilter with the specified algo.
-  std::vector<size_t> bwdFilterLimitBytes_;
-  /// Amount of GPU memory needed as workspace to be able to execute a
-  /// backwardData with the specified algo.
-  std::vector<size_t> bwdDataLimitBytes_;
-
-  /// Device work space address for each group.
-  std::vector<void*> workSpace_;
- /// Max number of groups.
- int maxGroups_;
- /// Total work space address in device for all groups.
- void* workSpaceData_;
- /// Size of total work space.
- size_t workSpaceInBytes_;
-
- /// Is or not select conv algorihtm.
- bool isSelectAlgo_;
-
- /// batchNum is used to record batch size. If the batch size is changed,
- /// the selection algorithm will be called.
- int batchNum_;
+ int outputOffset_;
public:
explicit CudnnConvLayer(const LayerConfig& config) : ConvBaseLayer(config) {}
~CudnnConvLayer();
- /**
- * Intialization. Initialize member variables and create tenor descriptor.
- */
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
- /**
- * Reshape is done each forward. Reshape tensor decriptor
- * inputDesc_, outputDesc_, convDesc_. And search the faster algo
- * or the fastest algo within a given memeory limit.
- */
- void reshape(int batchSize);
void forward(PassType passType);
void backward(const UpdateCallback& callback);
void addBiases();
diff --git a/paddle/gserver/layers/ExpandConvLayer.cpp b/paddle/gserver/layers/ExpandConvLayer.cpp
index df79c3e3037cf..80a6a62b5c0de 100644
--- a/paddle/gserver/layers/ExpandConvLayer.cpp
+++ b/paddle/gserver/layers/ExpandConvLayer.cpp
@@ -37,32 +37,29 @@ bool ExpandConvLayer::init(const LayerMap &layerMap,
caffeMode_ = conf.caffe_mode();
}
+ /* initialize the weightList */
+ CHECK(inputLayers_.size() == parameters_.size());
+ for (size_t i = 0; i < inputLayers_.size(); i++) {
+ size_t height, width;
+ height = filterPixels_[i] * filterChannels_[i];
+ width = numFilters_;
+
+ // create a new weight
+ CHECK_EQ(parameters_[i]->getSize(), width * height);
+ Weight* w = new Weight(height, width, parameters_[i]);
+ weights_.emplace_back(w);
+ }
+
return true;
}
-size_t ExpandConvLayer::getSize() {
+size_t ExpandConvLayer::getOutputSize() {
CHECK_NE(inputLayers_.size(), 0UL);
- imgSizeH_.clear();
- imgSizeW_.clear();
- outputH_.clear();
- outputW_.clear();
+ size_t layerSize = ConvBaseLayer::calOutputSize();
subN_.clear();
- size_t layerSize = 0;
for (size_t i = 0; i < inputLayers_.size(); i++) {
- imgSizeH_.push_back(inputLayers_[i]->getOutput().getFrameHeight());
- imgSizeW_.push_back(inputLayers_[i]->getOutput().getFrameWidth());
- if (imgSizeH_[i] == 0) imgSizeH_[i] = imgSize_[i];
- if (imgSizeW_[i] == 0) imgSizeW_[i] = imgSize_[i];
- outputH_.push_back(
- outputSize(imgSizeH_[i], filterSize_[i], padding_[i], stride_[i]));
- outputW_.push_back(
- outputSize(imgSizeW_[i], filterSize_[i], padding_[i], stride_[i]));
subN_.push_back(outputH_[i] * outputW_[i]);
- CHECK(layerSize == 0 || subN_[i] * size_t(numFilters_) == layerSize);
- layerSize = subN_[i] * numFilters_;
}
- getOutput().setFrameHeight(outputH_[0]);
- getOutput().setFrameWidth(outputW_[0]);
return layerSize;
}
@@ -119,7 +116,7 @@ void ExpandConvLayer::expandFwdOnce(MatrixPtr image, int inIdx, int startIdx) {
}
void ExpandConvLayer::addSharedBias() {
- size_t mapW = getSize() / numFilters_;
+ size_t mapW = getOutputValue()->getWidth() / numFilters_;
size_t mapH = getOutputValue()->getElementCnt() / mapW;
MatrixPtr out =
Matrix::create(getOutputValue()->getData(), mapH, mapW, false, useGpu_);
@@ -158,7 +155,7 @@ void ExpandConvLayer::forward(PassType passType) {
* transOutValue correspond sample to one row */
int batchSize = inputLayers_[0]->getOutputValue()->getWidth();
batchSize = inputLayers_[0]->getOutputValue()->getHeight();
- resetOutput(batchSize, getSize());
+ resetOutput(batchSize, getOutputSize());
MatrixPtr image = nullptr;
for (size_t i = 0; i != inputLayers_.size(); ++i) {
@@ -183,7 +180,7 @@ void ExpandConvLayer::forward(PassType passType) {
}
void ExpandConvLayer::bpropSharedBias(MatrixPtr biases, MatrixPtr v) {
- size_t mapW = getSize() / numFilters_;
+ size_t mapW = v->getWidth() / numFilters_;
size_t mapH = v->getElementCnt() / mapW;
MatrixPtr vTmp = Matrix::create(v->getData(), mapH, mapW, false, useGpu_);
diff --git a/paddle/gserver/layers/ExpandConvLayer.h b/paddle/gserver/layers/ExpandConvLayer.h
index fc3d69b1b7d14..030a3ba397ff4 100644
--- a/paddle/gserver/layers/ExpandConvLayer.h
+++ b/paddle/gserver/layers/ExpandConvLayer.h
@@ -37,14 +37,6 @@ class ExpandConvLayer : public ConvBaseLayer {
IntV subN_;
/// subK_ = channels_ * filterPixels_ * groups_.
IntV subK_;
- /// The spatial dimensions of height of input feature map.
- IntV imgSizeH_;
- /// The spatial dimensions of width of input feature map.
- IntV imgSizeW_;
- /// The spatial dimensions of height of output feature map.
- IntV outputH_;
- /// The spatial dimensions of width of output feature map.
- IntV outputW_;
/// Expand one sample at a time. shape:
/// (numChannels * filterPixels_, outputSizeH * outputSizeW)
MatrixPtr expandInput_;
@@ -58,7 +50,7 @@ class ExpandConvLayer : public ConvBaseLayer {
bool init(const LayerMap& layerMap, const ParameterMap& parameterMap);
- size_t getSize();
+ size_t getOutputSize();
/**
* Create or resize expandInput_.
diff --git a/paddle/gserver/layers/MixedLayer.cpp b/paddle/gserver/layers/MixedLayer.cpp
index 054ddd3a228ed..26b1360290ffb 100644
--- a/paddle/gserver/layers/MixedLayer.cpp
+++ b/paddle/gserver/layers/MixedLayer.cpp
@@ -41,9 +41,13 @@ bool MixedLayer::init(const LayerMap& layerMap,
}
operators_.emplace_back(Operator::create(operator_conf, useGpu_));
}
+
/* initialize biases_ */
if (biasParameter_.get() != NULL) {
-    biases_ = std::unique_ptr<Weight>(new Weight(1, getSize(), biasParameter_));
+ sharedBias_ = config_.shared_biases();
+ size_t psize = config_.bias_size();
+    biases_ = std::unique_ptr<Weight>(
+        new Weight(1, psize, biasParameter_));
}
return true;
@@ -119,12 +123,6 @@ void MixedLayer::forward(PassType passType) {
MatrixPtr outV = getOutputValue();
- /* add the bias-vector */
- if (biases_.get() != NULL) {
- REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
- outV->addBias(*(biases_->getW()), 1);
- }
-
for (size_t i = 0; i != inputLayers_.size(); ++i) {
if (projections_[i]) {
projections_[i]->forward(&getInput(i), &output_, passType);
@@ -140,6 +138,12 @@ void MixedLayer::forward(PassType passType) {
op->forward(ins, &output_, passType);
}
+ /* add the bias-vector */
+ if (biases_.get() != NULL) {
+ REGISTER_TIMER_INFO("FwBiasTimer", getName().c_str());
+ outV->addBias(*(biases_->getW()), 1, sharedBias_);
+ }
+
/* activation */ {
REGISTER_TIMER_INFO("FwAtvTimer", getName().c_str());
forwardActivation();
@@ -154,7 +158,7 @@ void MixedLayer::backward(const UpdateCallback& callback) {
if (biases_ && biases_->getWGrad()) {
REGISTER_TIMER_INFO("BpBiasTimer", getName().c_str());
- biases_->getWGrad()->collectBias(*getOutputGrad(), 1);
+ biases_->getWGrad()->collectBias(*getOutputGrad(), 1, sharedBias_);
/* Increasing the number of gradient */
biases_->getParameterPtr()->incUpdate(callback);
diff --git a/paddle/gserver/layers/MixedLayer.h b/paddle/gserver/layers/MixedLayer.h
index 9bac1355bd21f..5842e51e1d79d 100644
--- a/paddle/gserver/layers/MixedLayer.h
+++ b/paddle/gserver/layers/MixedLayer.h
@@ -58,5 +58,6 @@ class MixedLayer : public Layer {
/// the matrix size of projection state
std::vector projectionStateMatrixSize_;
  std::unique_ptr<Weight> biases_;
+ bool sharedBias_;
};
} // namespace paddle
diff --git a/paddle/gserver/tests/LayerGradUtil.cpp b/paddle/gserver/tests/LayerGradUtil.cpp
index 552a6c5b41c7f..bc7bee0e4bbc8 100644
--- a/paddle/gserver/tests/LayerGradUtil.cpp
+++ b/paddle/gserver/tests/LayerGradUtil.cpp
@@ -669,12 +669,14 @@ void testLayerGrad(TestConfig testConf, string testLayerName, size_t batchSize,
void testProjectionGrad(ProjectionConfig conf, InputType inputType,
size_t parameterSize, size_t batchSize, bool useGpu,
- bool testState) {
+ bool testState, int biasSize, bool sharedBias) {
TestConfig config;
conf.set_name(conf.type());
config.layerConfig.set_type("mixed");
config.layerConfig.set_size(conf.output_size());
- config.biasSize = config.layerConfig.size();
+ config.biasSize = biasSize == 0 ? config.layerConfig.size() : biasSize;
+ config.layerConfig.set_bias_size(config.biasSize);
+ config.layerConfig.set_shared_biases(sharedBias);
config.inputDefs.push_back(
{inputType, "layer_0", conf.input_size(), parameterSize});
*config.layerConfig.add_inputs()->mutable_proj_conf() = conf;
diff --git a/paddle/gserver/tests/LayerGradUtil.h b/paddle/gserver/tests/LayerGradUtil.h
index 1e608dc0620ab..3b9ec803959b3 100644
--- a/paddle/gserver/tests/LayerGradUtil.h
+++ b/paddle/gserver/tests/LayerGradUtil.h
@@ -217,7 +217,8 @@ void testLayerGrad(TestConfig testConf, string testLayerName, size_t batchSize,
void testProjectionGrad(ProjectionConfig conf, InputType inputType,
size_t parameterSize, size_t batchSize, bool useGpu,
- bool testState = false);
+ bool testState = false, int biasSize = 0,
+ bool sharedBias = false);
void testOperatorGrad(TestConfig& config, OperatorConfig& operatorConf,
size_t batchSize, bool useGpu, bool testState = false);
diff --git a/paddle/gserver/tests/img_conv_a.conf b/paddle/gserver/tests/img_conv_a.conf
new file mode 100644
index 0000000000000..940589ed9ac24
--- /dev/null
+++ b/paddle/gserver/tests/img_conv_a.conf
@@ -0,0 +1,39 @@
+#edit-mode: -*- python -*-
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=10)
+data = data_layer(name ="input", size=8*16*16)
+conv1 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
+ num_channels=8,
+ num_filters=16, stride=1,
+ bias_attr=False,
+ act=ReluActivation())
+conv2 = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
+ num_channels=8,
+ num_filters=16, stride=1,
+ bias_attr=False,
+ act=ReluActivation())
+
+concat = concat_layer(input=[conv1, conv2])
+
+conv = img_conv_layer(input=data, filter_size=1, filter_size_y=1,
+ num_channels=8,
+ num_filters=16, stride=1,
+ bias_attr=True,
+ act=LinearActivation())
+
+outputs(concat, conv)
diff --git a/paddle/gserver/tests/img_conv_b.conf b/paddle/gserver/tests/img_conv_b.conf
new file mode 100644
index 0000000000000..8ca9c94541504
--- /dev/null
+++ b/paddle/gserver/tests/img_conv_b.conf
@@ -0,0 +1,32 @@
+#edit-mode: -*- python -*-
+# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from paddle.trainer_config_helpers import *
+
+settings(batch_size=10)
+data = data_layer(name ="input", size=8*16*16)
+proj1 = conv_projection(input=data, filter_size=1, filter_size_y=1,
+ num_channels=8, num_filters=16, stride=1)
+proj2 = conv_projection(input=data, filter_size=1, filter_size_y=1,
+ num_channels=8, num_filters=16, stride=1)
+concat = concat_layer(input=[proj1, proj2], bias_attr=False, act=ReluActivation())
+
+proj = conv_projection(input=data, filter_size=1, filter_size_y=1,
+ num_channels=8, num_filters=16, stride=1)
+
+with mixed_layer(bias_attr=True, act=LinearActivation()) as conv:
+ conv += proj
+
+outputs(concat, conv)
diff --git a/paddle/gserver/tests/test_LayerGrad.cpp b/paddle/gserver/tests/test_LayerGrad.cpp
index eab9bf84141a2..bf2c2e0499941 100644
--- a/paddle/gserver/tests/test_LayerGrad.cpp
+++ b/paddle/gserver/tests/test_LayerGrad.cpp
@@ -134,6 +134,45 @@ TEST(Projection, identity) {
}
}
+
+#ifndef PADDLE_ONLY_CPU
+TEST(Projection, conv) {
+ const int NUM_FILTERS = 16;
+ const int FILTER_SIZE = 2;
+ const int FILTER_SIZE_Y = 3;
+ const int CHANNELS = 3;
+ const int IMAGE_SIZE = 16;
+
+ ProjectionConfig conf;
+ conf.set_type("conv");
+ conf.set_num_filters(NUM_FILTERS);
+
+ ConvConfig* conv = conf.mutable_conv_conf();
+ conv->set_filter_size(FILTER_SIZE);
+ conv->set_filter_size_y(FILTER_SIZE_Y);
+ conv->set_channels(CHANNELS);
+ conv->set_padding(0);
+ conv->set_padding_y(1);
+ conv->set_stride(2);
+ conv->set_stride_y(2);
+ conv->set_groups(1);
+ conv->set_filter_channels(conv->channels() / conv->groups());
+ conv->set_img_size(IMAGE_SIZE);
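+  // Expected output size in caffe mode:
+  //   (imageSize + 2 * padding - filterSize) / stride + 1.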
+ int outputSize = (2 * conv->padding() + conv->img_size() -
+ conv->filter_size()) / conv->stride() + 1;
+ int outputSizeY = (2 * conv->padding_y() + conv->img_size() -
+ conv->filter_size_y()) / conv->stride_y() + 1;
+ conv->set_output_x(outputSize);
+ conf.set_input_size(IMAGE_SIZE * IMAGE_SIZE * CHANNELS);
+ conf.set_output_size(outputSize * outputSizeY * NUM_FILTERS);
+
+ testProjectionGrad(conf, INPUT_DATA,
+ /* parameterSize */ NUM_FILTERS * CHANNELS * FILTER_SIZE * FILTER_SIZE_Y,
+ /* batchSize */ 100, true, false, NUM_FILTERS, true);
+}
+#endif
+
+
TEST(Layer, concat) {
TestConfig config;
config.biasSize = 0;
diff --git a/paddle/gserver/tests/test_NetworkCompare.cpp b/paddle/gserver/tests/test_NetworkCompare.cpp
index b3ef53067301b..8d3eac5aca8d1 100644
--- a/paddle/gserver/tests/test_NetworkCompare.cpp
+++ b/paddle/gserver/tests/test_NetworkCompare.cpp
@@ -236,6 +236,15 @@ TEST(Compare, img_pool) {
compareNetwork(config_file_a, config_file_b);
FLAGS_use_gpu = useGpu;
}
+
+TEST(Compare, img_conv) {
+ std::string config_file_a = "./gserver/tests/img_conv_a.conf";
+ std::string config_file_b = "./gserver/tests/img_conv_b.conf";
+ bool useGpu = FLAGS_use_gpu;
+ FLAGS_use_gpu = true;
+ compareNetwork(config_file_a, config_file_b);
+ FLAGS_use_gpu = useGpu;
+}
#endif
diff --git a/paddle/math/Matrix.cpp b/paddle/math/Matrix.cpp
index 843eabc97d642..aaeae98f0d28b 100644
--- a/paddle/math/Matrix.cpp
+++ b/paddle/math/Matrix.cpp
@@ -340,6 +340,15 @@ void GpuMatrix::addBias(Matrix& b, real scale) {
BaseMatrix::addBias(b, scale);
}
+void GpuMatrix::addSharedBias(Matrix& b, real scale) {
+ CHECK(b.getHeight() == 1) << "the Bias should be a vector";
+ CHECK_LE(b.getWidth(), getWidth());
+ CHECK_EQ(getWidth() % b.getWidth(), 0UL);
+ hl_matrix_add_shared_bias(getData(), b.getData(), b.getWidth(),
+ getHeight(), getWidth(), scale);
+}
+
+
void GpuMatrix::collectBias(Matrix& a, real scale) {
CHECK_EQ(getHeight(), (size_t)1);
CHECK_EQ(width_, a.getWidth());
@@ -354,6 +363,14 @@ void GpuMatrix::collectBias(Matrix& a, real scale) {
}
}
+void GpuMatrix::collectSharedBias(Matrix& a, real scale) {
+ CHECK_EQ(getHeight(), (size_t)1);
+ CHECK_EQ(a.getWidth() % getWidth(), 0UL);
+ hl_matrix_collect_shared_bias(getData(), a.getData(), getWidth(),
+ a.getHeight(), a.getWidth(), scale);
+}
+
+
void GpuMatrix::sequenceAvgForward(Matrix& a,
const IVector& startsPos,
int mode) {
@@ -1983,6 +2000,24 @@ void CpuMatrix::addBias(Matrix& b, real scale) {
}
}
+void CpuMatrix::addSharedBias(Matrix& b, real scale) {
+ CHECK_EQ(b.getHeight(), (size_t)1);
+ real* aData = getData();
+ real* bData = b.getData();
+ size_t numSamples = getHeight();
+ size_t channel = b.getWidth();
+ CHECK_EQ(getWidth() % channel, 0UL);
+ size_t dim = getWidth() / channel;
+
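+  // Each row is laid out channel by channel: add bias[c] to the dim
+  // consecutive columns belonging to channel c, for every sample.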
+ for (size_t i = 0; i < numSamples; i++) {
+ for (size_t c = 0; c < channel; c++) {
+ for (size_t j = 0; j < dim; j++) {
+ aData[i * getStride() + c * dim + j] += scale * bData[c];
+ }
+ }
+ }
+}
+
void CpuMatrix::collectBias(Matrix& a, real scale) {
CHECK_EQ(getHeight(), (size_t)1);
CHECK_EQ(width_, a.getWidth());
@@ -2000,6 +2035,23 @@ void CpuMatrix::collectBias(Matrix& a, real scale) {
}
}
+void CpuMatrix::collectSharedBias(Matrix& a, real scale) {
+ CHECK_EQ(getHeight(), (size_t)1);
+ real* B = getData();
+ real* A = a.getData();
+ size_t numSamples = a.getHeight();
+ size_t channel = getWidth();
+ CHECK_EQ(a.getWidth() % channel, 0UL);
+ size_t dim = a.getWidth() / channel;
+ for (size_t i = 0; i < numSamples; i++) {
+ for (size_t c = 0; c < channel; c++) {
+ for (size_t j = 0; j < dim; j++) {
+ B[c] += scale * A[i * channel * dim + c * dim + j];
+ }
+ }
+ }
+}
+
void CpuMatrix::sequenceAvgForward(Matrix& a,
const IVector& startsPos,
int mode) {
diff --git a/paddle/math/Matrix.h b/paddle/math/Matrix.h
index 047c76a8604cc..52cbed528ca8b 100644
--- a/paddle/math/Matrix.h
+++ b/paddle/math/Matrix.h
@@ -343,11 +343,35 @@ class Matrix : public BaseMatrix {
LOG(FATAL) << "Not implemented";
}
+ virtual void addSharedBias(Matrix& b, real scale) {
+ LOG(FATAL) << "Not implemented";
+ }
+
+ virtual void addBias(Matrix& b, real scale, bool sharedBias) {
+ if (!sharedBias) {
+ addBias(b, scale);
+ } else {
+ addSharedBias(b, scale);
+ }
+ }
+
/// add each sample from a to this.
virtual void collectBias(Matrix& a, real scale) {
LOG(FATAL) << "Not implemented";
}
+ virtual void collectSharedBias(Matrix& a, real scale) {
+ LOG(FATAL) << "Not implemented";
+ }
+
+ virtual void collectBias(Matrix& a, real scale, bool sharedBias) {
+ if (!sharedBias) {
+ collectBias(a, scale);
+ } else {
+ collectSharedBias(a, scale);
+ }
+ }
+
virtual void sequenceAvgForward(Matrix& a, const IVector& startsPos,
int mode) {
LOG(FATAL) << "Not implemented";
@@ -1021,6 +1045,7 @@ class GpuMatrix : public Matrix {
/// add b to each sample of this.
void addBias(Matrix& b, real scale);
+ void addSharedBias(Matrix& b, real scale);
/**
* @code
@@ -1028,6 +1053,7 @@ class GpuMatrix : public Matrix {
* @endcode
*/
void collectBias(Matrix& a, real scale);
+ void collectSharedBias(Matrix& a, real scale);
void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode);
@@ -1341,9 +1367,11 @@ class CpuMatrix : public Matrix {
public:
/// add b to each sample of this.
void addBias(Matrix& b, real scale);
+ void addSharedBias(Matrix& b, real scale);
/// add each sample of a to this.
void collectBias(Matrix& a, real scale);
+ void collectSharedBias(Matrix& a, real scale);
void sequenceAvgForward(Matrix& a, const IVector& startsPos, int mode);
diff --git a/paddle/math/tests/test_matrixCompare.cpp b/paddle/math/tests/test_matrixCompare.cpp
index ac160479a9dfc..0ddf7e0dfc386 100644
--- a/paddle/math/tests/test_matrixCompare.cpp
+++ b/paddle/math/tests/test_matrixCompare.cpp
@@ -21,6 +21,8 @@ limitations under the License. */
#include "paddle/math/SparseMatrix.h"
#include <gtest/gtest.h>
#include "paddle/gserver/tests/TestUtil.h"
+#include "paddle/utils/Stat.h"
+
using namespace paddle; // NOLINT
using namespace std; // NOLINT
@@ -2071,6 +2073,60 @@ TEST(Matrix, MaxOutFwdBwd) {
}
}
+void testAddSharedBias(int numSamples, int dim, int channel) {
+ MatrixPtr cpuData = std::make_shared<CpuMatrix>(numSamples, dim);
+ MatrixPtr gpuData = std::make_shared<GpuMatrix>(numSamples, dim);
+
+ MatrixPtr cpuBias = std::make_shared<CpuMatrix>(1, channel);
+ MatrixPtr gpuBias = std::make_shared<GpuMatrix>(1, channel);
+
+ cpuData->randomizeUniform();
+ gpuData->copyFrom(*cpuData);
+ cpuBias->randomizeUniform();
+ gpuBias->copyFrom(*cpuBias);
+
+ cpuData->addSharedBias(*cpuBias, 1.0);
+ gpuData->addSharedBias(*gpuBias, 1.0);
+
+ MatrixPtr check = std::make_shared<CpuMatrix>(numSamples, dim);
+ check->copyFrom(*gpuData);
+ MatrixCheckErr(*cpuData, *check);
+}
+
+void testCollectSharedBias(int numSamples, int dim, int channel) {
+ MatrixPtr cpuData = std::make_shared<CpuMatrix>(numSamples, dim);
+ MatrixPtr gpuData = std::make_shared<GpuMatrix>(numSamples, dim);
+
+ MatrixPtr cpuBias = std::make_shared<CpuMatrix>(1, channel);
+ MatrixPtr gpuBias = std::make_shared<GpuMatrix>(1, channel);
+
+ cpuData->randomizeUniform();
+ gpuData->copyFrom(*cpuData);
+ cpuBias->randomizeUniform();
+ gpuBias->copyFrom(*cpuBias);
+
+ cpuBias->collectSharedBias(*cpuData, 1.0);
+ gpuBias->collectSharedBias(*gpuData, 1.0);
+
+ MatrixPtr check = std::make_shared<CpuMatrix>(1, channel);
+ check->copyFrom(*gpuBias);
+ MatrixCheckErr(*cpuBias, *check);
+}
+
+
+TEST(Matrix, sharedBias) {
+ for (auto numSamples : {1, 100, 520}) {
+ for (auto dim : {100 * 16, 100 * 32}) {
+ for (auto channel : {8, 16}) {
+ VLOG(3) << " numSamples=" << numSamples << " dim=" << dim
+ << " channel=" << channel;
+ testAddSharedBias(numSamples, dim, channel);
+ testCollectSharedBias(numSamples, dim, channel);
+ }
+ }
+ }
+}
+
int main(int argc, char** argv) {
testing::InitGoogleTest(&argc, argv);
initMain(argc, argv);
diff --git a/paddle/trainer/CMakeLists.txt b/paddle/trainer/CMakeLists.txt
index 08b411d2ccbae..06c019f0a9775 100644
--- a/paddle/trainer/CMakeLists.txt
+++ b/paddle/trainer/CMakeLists.txt
@@ -7,6 +7,7 @@ set(TRAINER_SOURCES
Tester.cpp
Trainer.cpp
TrainerInternal.cpp
+ TrainerBenchmark.cpp
ThreadParameterUpdater.cpp
TrainerInternalConfig.cpp
TrainerConfigHelper.cpp)
diff --git a/paddle/trainer/Trainer.h b/paddle/trainer/Trainer.h
index 4f4811a139e74..7762722456c44 100644
--- a/paddle/trainer/Trainer.h
+++ b/paddle/trainer/Trainer.h
@@ -99,6 +99,7 @@ class Trainer {
void startTrainPass();
void finishTrainPass();
void trainOneDataBatch(DataBatch& dataBatch);
+ void time();
/**
* given a dataBatch and the current parameter value
diff --git a/paddle/trainer/TrainerBenchmark.cpp b/paddle/trainer/TrainerBenchmark.cpp
new file mode 100644
index 0000000000000..54862e95b4a73
--- /dev/null
+++ b/paddle/trainer/TrainerBenchmark.cpp
@@ -0,0 +1,71 @@
+/* Copyright (c) 2016 Baidu, Inc. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#undef PADDLE_DISABLE_TIMER
+
+#include "Trainer.h"
+#include "paddle/utils/Stat.h"
+#include "paddle/utils/Util.h"
+
+P_DECLARE_int32(test_period);
+
+P_DEFINE_bool(feed_data, false, "Whether to read data from DataProvider.");
+
+namespace paddle {
+
+void Trainer::time() {
+ startTrain();
+
+ trainerInternal_.getParameterUpdater()->startPass();
+ evaluator_->start();
+
+ DataBatch dataBatch;
+ int32_t batchSize = config_->getOptConfig().batch_size();
+ int32_t num = dataProvider_->getNextBatch(batchSize, &dataBatch);
+ CHECK_EQ(num, batchSize) << "The sample number does not match the batch size "
+ << num << " != " << batchSize;
+
+ CHECK(dataBatch.getSize()) << "No data from data provider";
+
+ std::vector<Argument> outputs;
+ // burn-in runs, excluded from the timing statistics
+ LOG(INFO) << "Burn-in start...";
+ for (int n = 0; n < 10; ++n) {
+ trainerInternal_.trainOneBatch(n, dataBatch, &outputs);
+ }
+ LOG(INFO) << "Burn-in finished.";
+
+ for (int n = 0; n < FLAGS_test_period; n++) {
+ if (FLAGS_feed_data) {
+ REGISTER_TIMER("GetData");
+ num = dataProvider_->getNextBatch(batchSize, &dataBatch);
+ }
+
+ if (num != batchSize) {
+ break;
+ }
+
+ {
+ REGISTER_TIMER("FwdBwd");
+ trainerInternal_.trainOneBatch(n, dataBatch, &outputs);
+ }
+ }
+ globalStat.setThreadInfo(true);
+ globalStat.printSegTimerStatus();
+ globalStat.reset();
+
+ finishTrain();
+}
+
+} // namespace paddle
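
Usage note (not part of the patch): the benchmark above is selected at run time with `--job=time` (wired into TrainerMain.cpp below). `--test_period` sets how many batches are timed after the ten burn-in batches, and `--feed_data` decides whether a fresh batch is fetched from the DataProvider on every timed iteration.
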
diff --git a/paddle/trainer/TrainerMain.cpp b/paddle/trainer/TrainerMain.cpp
index 94266639f94ad..a486cc383ace6 100644
--- a/paddle/trainer/TrainerMain.cpp
+++ b/paddle/trainer/TrainerMain.cpp
@@ -103,6 +103,8 @@ int main(int argc, char** argv) {
trainer.checkGradient();
} else if (FLAGS_job == "test") {
trainer.test();
+ } else if (FLAGS_job == "time") {
+ trainer.time();
} else {
LOG(FATAL) << "Unknown job type: " << FLAGS_job;
}
diff --git a/proto/ModelConfig.proto.m4 b/proto/ModelConfig.proto.m4
index 70c1f8d563238..79e76b6bf1bdd 100644
--- a/proto/ModelConfig.proto.m4
+++ b/proto/ModelConfig.proto.m4
@@ -255,7 +255,7 @@ sinclude(`ModelConfigLayer.proto.m4')
// (which is how convnets are usually trained). Setting this to
// false will untie the biases, yielding a separate bias for
// every location at which the filter is applied.
- optional bool shared_biases = 8;
+ optional bool shared_biases = 8 [default = false];
// Valid values are ones that divide the area of the output
// grid in this convolutional layer. For example if this layer
@@ -379,6 +379,9 @@ sinclude(`ModelConfigLayer.proto.m4')
// use to compute moving mean and variance.
optional real moving_average_fraction = 47 [default = 0.9];
+
+ // bias size
+ optional uint32 bias_size = 48 [default = 0];
}
message EvaluatorConfig {
diff --git a/python/paddle/trainer/config_parser.py b/python/paddle/trainer/config_parser.py
index fe8a5e5d48767..e9098943165fd 100644
--- a/python/paddle/trainer/config_parser.py
+++ b/python/paddle/trainer/config_parser.py
@@ -632,6 +632,44 @@ def calc_parameter_dims(self, input_size, output_size):
_total_pad = 0
+@config_class
+class ConvProjection(Projection):
+ type = 'conv'
+
+ def __init__(
+ self,
+ input_layer_name,
+ num_filters=None,
+ conv_conf=None,
+ **xargs):
+ super(ConvProjection, self).__init__(input_layer_name, **xargs)
+
+ if num_filters is not None:
+ self.proj_conf.num_filters = num_filters
+
+ parse_conv(conv_conf,
+ input_layer_name,
+ self.proj_conf.conv_conf)
+ # TODO: support rectangle input
+ self.proj_conf.output_size = (self.proj_conf.conv_conf.output_x ** 2) * num_filters
+
+ def calc_output_size(self, input_layer_config):
+ return self.proj_conf.output_size
+
+ def calc_parameter_size(self, input_size, output_size):
+ co = self.proj_conf.num_filters
+ ci = self.proj_conf.conv_conf.channels
+ fh = self.proj_conf.conv_conf.filter_size
+ fw = self.proj_conf.conv_conf.filter_size_y
+ return co * ci * fh * fw
+
+ def calc_bias_size(self):
+ return self.proj_conf.num_filters
+
+ def calc_parameter_dims(self, input_size, output_size):
+ return None
+
+
# Define a operator for mixed layer
@config_class
class Operator(Cfg):
@@ -2528,8 +2566,15 @@ def __init__(
record_operator_conf = self.config.operator_confs.add()
record_operator_conf.CopyFrom(operator_conf)
+ psize = self.config.size
+ if isinstance(self.inputs[0], ConvProjection):
+ self.config.shared_biases = True
+ psize = 0
+ for input in self.inputs:
+ psize += input.calc_bias_size()
- self.create_bias_parameter(bias, self.config.size)
+ self.config.bias_size = psize
+ self.create_bias_parameter(bias, psize)
if error_clipping_threshold is not None:
self.config.error_clipping_threshold = error_clipping_threshold
@@ -2547,8 +2592,10 @@ def __init__(
self,
name,
inputs,
+ bias=False,
**xargs):
config_assert(inputs, 'inputs cannot be empty')
+ config_assert(not bias, 'ConcatenateLayer does not support bias.')
super(ConcatenateLayer, self).__init__(
name, 'concat', 0, inputs=inputs, **xargs)
size = 0
@@ -2567,10 +2614,19 @@ def __init__(
self,
name,
inputs,
+ bias=False,
**xargs):
config_assert(inputs, 'inputs cannot be empty')
super(ConcatenateLayer2, self).__init__(
name, 'concat2', 0, inputs=inputs, **xargs)
+
+ if isinstance(self.inputs[0], ConvProjection):
+ for input_index in xrange(len(self.inputs) - 1):
+ input = self.inputs[input_index + 1]
+ config_assert(isinstance(input, ConvProjection),
+ "When the first input of ConcatenateLayer2 is a ConvProjection, "
+ "all the other inputs should also be ConvProjections.")
+
size = 0
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
@@ -2596,6 +2652,16 @@ def __init__(
input.proj_conf.output_size)
self.create_input_parameter(input_index, psize, dims)
+ psize = self.config.size
+ if isinstance(self.inputs[0], ConvProjection):
+ self.config.shared_biases = True
+ psize = 0
+ for input in self.inputs:
+ psize += input.calc_bias_size()
+
+ self.config.bias_size = psize
+ self.create_bias_parameter(bias, psize)
+
@config_layer('recurrent')
class RecurrentLayer(LayerBase):
def __init__(
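
For reference (not part of the patch), a hypothetical sizing example of the shared-bias logic added above for layers whose inputs are ConvProjections; the filter counts are made up.

```python
# Each ConvProjection contributes calc_bias_size() == num_filters bias entries,
# so the layer's bias parameter has one entry per output channel instead of one
# per output element, and shared_biases is forced to True.
num_filters = [64, 32]           # two hypothetical ConvProjection inputs
bias_size = sum(num_filters)     # -> config.bias_size = 96
```
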
diff --git a/python/paddle/trainer_config_helpers/layers.py b/python/paddle/trainer_config_helpers/layers.py
index 7df9108ae82a4..9a23c02431d18 100644
--- a/python/paddle/trainer_config_helpers/layers.py
+++ b/python/paddle/trainer_config_helpers/layers.py
@@ -34,7 +34,7 @@
"table_projection", "mixed_layer", "data_layer",
"embedding_layer", "fc_layer", "grumemory",
"pooling_layer", "lstmemory", "last_seq", "first_seq",
- "cos_sim", "hsigmoid",
+ "cos_sim", "hsigmoid", "conv_projection",
"regression_cost", 'classification_cost', "LayerOutput",
'img_conv_layer', 'img_pool_layer', 'batch_norm_layer',
'img_cmrnorm_layer', 'addto_layer',
@@ -1984,7 +1984,7 @@ def addto_layer(input, act=None, name=None, bias_attr=None,
@wrap_act_default(act=IdentityActivation())
@wrap_name_default("concat")
@layer_support()
-def concat_layer(input, act=None, name=None, layer_attr=None):
+def concat_layer(input, act=None, name=None, layer_attr=None, bias_attr=None):
"""
Concat all input vector into one huge vector.
Inputs can be list of LayerOutput or list of projection.
@@ -2043,10 +2043,14 @@ def __reduce_concat_type__(a, b):
layer_type = (LayerType.CONCAT_LAYER if is_concat_layer
else LayerType.CONCAT_PROJ_LAYER)
+ if layer_type == LayerType.CONCAT_LAYER:
+ assert not bias_attr
+
Layer(
name=name, type=layer_type,
inputs=[x.name for x in input] if is_concat_layer else input,
active_type=act.name,
+ bias=ParamAttr.to_bias(bias_attr),
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
@@ -2950,6 +2954,103 @@ def conv_operator(img, filter, filter_size, num_filters,
op.origin = [img, filter]
return op
+@wrap_param_attr_default()
+def conv_projection(input, filter_size, num_filters,
+ num_channels=None, stride=1, padding=0,
+ filter_size_y=None, stride_y=None, padding_y=None,
+ groups=1, param_attr=None):
+ """
+ ConvProjection takes a layer as input and performs a convolution on it
+ with the filter weights.
+
+ Different from img_conv_layer and conv_op, conv_projection is a Projection,
+ which can be used in mixed_layer and concat_layer. It uses cuDNN to implement
+ the convolution and only supports GPU mode.
+
+ The example usage is:
+
+ .. code-block:: python
+
+ proj = conv_projection(input=input1,
+ filter_size=3,
+ num_filters=64,
+ num_channels=64)
+
+ :param input: input layer
+ :type input: LayerOutput
+ :param filter_size: The x dimension of a filter kernel.
+ :type filter_size: int
+ :param filter_size_y: The y dimension of a filter kernel. Since
+ PaddlePaddle now supports rectangular filters,
+ the filter's shape can be (filter_size, filter_size_y).
+ :type filter_size_y: int
+ :param num_filters: number of output channels (filters).
+ :type num_filters: int
+ :param num_channels: number of channels of the input data.
+ :type num_channels: int
+ :param stride: The x dimension of the stride.
+ :type stride: int
+ :param stride_y: The y dimension of the stride.
+ :type stride_y: int
+ :param padding: The x dimension of padding.
+ :type padding: int
+ :param padding_y: The y dimension of padding.
+ :type padding_y: int
+ :param groups: The group number.
+ :type groups: int
+ :param param_attr: Convolution param attribute. None means default attribute
+ :type param_attr: ParameterAttribute
+ :return: A ConvProjection Object.
+ :rtype: ConvProjection
+ """
+ if num_channels is None:
+ assert input.num_filters is not None
+ num_channels = input.num_filters
+
+ if filter_size_y is None:
+ if isinstance(filter_size, collections.Sequence):
+ assert len(filter_size) == 2
+ filter_size, filter_size_y = filter_size
+ else:
+ filter_size_y = filter_size
+
+ if stride_y is None:
+ if isinstance(stride, collections.Sequence):
+ assert len(stride) == 2
+ stride, stride_y = stride
+ else:
+ stride_y = stride
+
+ if padding_y is None:
+ if isinstance(padding, collections.Sequence):
+ assert len(padding) == 2
+ padding, padding_y = padding
+ else:
+ padding_y = padding
+
+ if param_attr.attr.get('initial_smart'):
+ # special initial for conv layers.
+ init_w = (2.0 / (filter_size ** 2 * num_channels)) ** 0.5
+ param_attr.attr["initial_mean"] = 0.0
+ param_attr.attr["initial_std"] = init_w
+ param_attr.attr["initial_strategy"] = 0
+ param_attr.attr["initial_smart"] = False
+
+ proj = ConvProjection(input_layer_name=input.name,
+ num_filters=num_filters,
+ conv_conf=Conv(filter_size=filter_size,
+ padding=padding,
+ stride=stride,
+ channels=num_channels,
+ filter_size_y=filter_size_y,
+ padding_y=padding_y,
+ stride_y=stride_y,
+ groups=groups),
+ **param_attr.attr)
+
+ proj.origin = input
+ return proj
+
@wrap_name_default()
@layer_support()
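
Not part of the patch: a hypothetical configuration sketch showing how the new conv_projection is intended to combine with concat_layer (which now accepts bias_attr); the input layer `feat` and all hyper-parameters are made up.

```python
# Two parallel convolutions over the same feature map, expressed as projections.
conv3 = conv_projection(input=feat, filter_size=3, num_filters=64,
                        num_channels=64, stride=1, padding=1)
conv5 = conv_projection(input=feat, filter_size=5, num_filters=64,
                        num_channels=64, stride=1, padding=2)

# concat_layer now accepts bias_attr; with ConvProjection inputs the bias is
# shared per output channel (see the config_parser.py changes above).
combined = concat_layer(input=[conv3, conv5],
                        bias_attr=ParamAttr(initial_mean=0., initial_std=0.),
                        act=ReluActivation())
```
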
diff --git a/python/paddle/trainer_config_helpers/networks.py b/python/paddle/trainer_config_helpers/networks.py
index 65512b327cdc6..bce88f93626ec 100644
--- a/python/paddle/trainer_config_helpers/networks.py
+++ b/python/paddle/trainer_config_helpers/networks.py
@@ -29,7 +29,7 @@
"img_conv_bn_pool", 'dropout_layer', 'lstmemory_group',
'lstmemory_unit', 'small_vgg', 'img_conv_group', 'vgg_16_network',
'gru_unit', 'gru_group', 'simple_gru', 'simple_attention',
- 'text_conv_pool',
+ 'simple_gru2', 'bidirectional_gru', 'text_conv_pool',
'bidirectional_lstm', 'inputs', 'outputs']
@@ -811,22 +811,37 @@ def simple_gru(input,
gru_layer_attr=None
):
"""
- simple_gru is also a recurrent layer group version Gated Recurrent Unit as
- gru_group. The difference only lies in implemention details.
+ You may see gru_step_layer and grumemory in layers.py, and gru_unit, gru_group
+ and simple_gru in networks.py. The reason there are so many interfaces is
+ that we have two ways to implement recurrent neural networks. One way is to
+ use a single complete layer that implements the rnn (simple rnn, gru or lstm)
+ over multiple time steps, such as recurrent_layer, lstmemory and grumemory. But
+ the multiplication operation :math:`W x_t` is not computed in these layers.
+ See details in their interfaces in layers.py.
+ The other way is to use a recurrent group, which can assemble a
+ series of layers to compute the rnn step by step. This way is flexible for
+ attention mechanisms or other complex connections.
+
+ - gru_step_layer: computes only one rnn step. It needs a memory as input
+ and can be used in a recurrent group.
+ - gru_unit: a wrapper of gru_step_layer with memory.
+ - gru_group: a GRU cell implemented by a combination of multiple layers in
+ a recurrent group.
+ But :math:`W x_t` is not done in the group.
+ - grumemory: a GRU cell implemented by one layer, which does the same
+ calculation as gru_group and is faster than gru_group.
+ - simple_gru: a complete GRU implementation including :math:`W x_t` and
+ gru_group. :math:`W` contains :math:`W_r`, :math:`W_z` and :math:`W`; see the
+ formula in grumemory.
+
The computational speed is that, grumemory is relatively better than
gru_group, and gru_group is relatively better than simple_gru.
- simple_gru does exactly the same calculation as the grumemory layer does.
- Please see grumemory in layers.py for more detail about the maths.
-
The example usage is:
.. code-block:: python
- gru = gur_group(input=[layer1],
- size=256,
- act=TanhActivation(),
- gate_act=SigmoidActivation())
+ gru = simple_gru(input=[layer1], size=256)
:param input: input layer name.
:type input: LayerOutput
@@ -863,6 +878,132 @@ def simple_gru(input,
gru_layer_attr=gru_layer_attr)
+@wrap_name_default('simple_gru2')
+def simple_gru2(input,
+ size,
+ name=None,
+ reverse=False,
+ mixed_param_attr=None,
+ mixed_bias_attr=None,
+ gru_param_attr=None,
+ gru_bias_attr=None,
+ act=None,
+ gate_act=None,
+ mixed_layer_attr=None,
+ gru_cell_attr=None
+ ):
+ """
+ simple_gru2 computes the same thing as simple_gru, but uses grumemory instead
+ of gru_group internally. Please see grumemory in layers.py for more detail
+ about the maths. simple_gru2 is faster than simple_gru.
+
+ The example usage is:
+
+ .. code-block:: python
+
+ gru = simple_gru2(input=[layer1], size=256)
+
+ :param input: input layer name.
+ :type input: LayerOutput
+ :param name: name of the gru group.
+ :type name: basestring
+ :param size: hidden size of the gru.
+ :type size: int
+ :param reverse: whether to process the input data in a reverse order
+ :type reverse: bool
+ :param act: type of the activation
+ :type act: BaseActivation
+ :param gate_act: type of the gate activation
+ :type gate_act: BaseActivation
+ :param gru_bias_attr: bias. False means no bias, None means default bias.
+ :type gru_bias_attr: ParameterAttribute|False
+ :param gru_cell_attr: Extra attribute of the gru cell layer.
+ :type gru_cell_attr: ExtraLayerAttribute|None
+ :return: the gru group.
+ :rtype: LayerOutput
+ """
+ with mixed_layer(name='%s_transform' % name,
+ size=size * 3,
+ bias_attr=mixed_bias_attr,
+ layer_attr=mixed_layer_attr) as m:
+ m += full_matrix_projection(input=input, param_attr=mixed_param_attr)
+
+ return grumemory(name=name,
+ size=size,
+ input=m,
+ reverse=reverse,
+ bias_attr=gru_bias_attr,
+ param_attr=gru_param_attr,
+ act=act,
+ gate_act=gate_act,
+ layer_attr=gru_cell_attr)
+
+
+@wrap_name_default("bidirectional_gru")
+def bidirectional_gru(input, size, name=None, return_seq=False,
+ fwd_mixed_param_attr=None, fwd_mixed_bias_attr=None,
+ fwd_gru_param_attr=None, fwd_gru_bias_attr=None,
+ fwd_act=None, fwd_gate_act=None,
+ fwd_mixed_layer_attr=None, fwd_gru_cell_attr=None,
+
+ bwd_mixed_param_attr=None, bwd_mixed_bias_attr=None,
+ bwd_gru_param_attr=None, bwd_gru_bias_attr=None,
+ bwd_act=None, bwd_gate_act=None,
+ bwd_mixed_layer_attr=None, bwd_gru_cell_attr=None,
+
+ last_seq_attr=None, first_seq_attr=None,
+ concat_attr=None, concat_act=None):
+ """
+ A bidirectional_gru is a recurrent unit that iterates over the input
+ sequence in both forward and backward orders, and then concatenates the two
+ outputs to form the final output. However, concatenation of the two outputs
+ is not the only way to form the final output; you could also, for example,
+ just add them together.
+
+ The example usage is:
+
+ .. code-block:: python
+
+ bi_gru = bidirectional_gru(input=[input1], size=512)
+
+ :param name: bidirectional gru layer name.
+ :type name: basestring
+ :param input: input layer.
+ :type input: LayerOutput
+ :param size: gru layer size.
+ :type size: int
+ :param return_seq: If set False, outputs of the last time step are
+ concatenated and returned.
+ If set True, the entire output sequences that are
+ processed in forward and backward directions are
+ concatenated and returned.
+ :type return_seq: bool
+ :return: LayerOutput object.
+ :rtype: LayerOutput
+ """
+ args = locals()
+
+ fw = simple_gru2(name='%s_fw' % name, input=input, size=size,
+ **dict((k[len('fwd_'):], v) for k, v in args.iteritems()
+ if k.startswith('fwd_')))
+
+ bw = simple_gru2(name="%s_bw" % name, input=input, size=size,
+ reverse=True,
+ **dict((k[len('bwd_'):], v) for k, v in args.iteritems()
+ if k.startswith('bwd_')))
+
+ if return_seq:
+ return concat_layer(name=name, input=[fw, bw], layer_attr=concat_attr,
+ act=concat_act)
+ else:
+ fw_seq = last_seq(name="%s_fw_last" % name, input=fw,
+ layer_attr=last_seq_attr)
+ bw_seq = first_seq(name="%s_bw_last" % name, input=bw,
+ layer_attr=first_seq_attr)
+ return concat_layer(name=name, input=[fw_seq, bw_seq],
+ layer_attr=concat_attr, act=concat_act)
+
+
@wrap_name_default("bidirectional_lstm")
def bidirectional_lstm(input, size, name=None, return_seq=False,
fwd_mat_param_attr=None, fwd_bias_param_attr=None,
@@ -893,7 +1034,7 @@ def bidirectional_lstm(input, size, name=None, return_seq=False,
.. code-block:: python
- lstm_step = bidirectional_lstm(input=[input1], size=512)
+ bi_lstm = bidirectional_lstm(input=[input1], size=512)
:param name: bidirectional lstm layer name.
:type name: basestring
@@ -907,7 +1048,7 @@ def bidirectional_lstm(input, size, name=None, return_seq=False,
processed in forward and backward directions are
concatenated and returned.
:type return_seq: bool
- :return: lstm layer name.
+ :return: LayerOutput object according to return_seq.
:rtype: LayerOutput
"""
args = locals()
diff --git a/python/paddle/trainer_config_helpers/tests/configs/check.md5 b/python/paddle/trainer_config_helpers/tests/configs/check.md5
index d1b22b34903df..72dfdad7bdd40 100644
--- a/python/paddle/trainer_config_helpers/tests/configs/check.md5
+++ b/python/paddle/trainer_config_helpers/tests/configs/check.md5
@@ -1,10 +1,11 @@
86c0815275a9d5eb902e23c6a592f58a img_layers.protostr
a5d9259ff1fd7ca23d0ef090052cb1f2 last_first_seq.protostr
9c038249ec8ff719753a746cdb04c026 layer_activations.protostr
-5913f87b39cee3b2701fa158270aca26 projections.protostr
+34e04043cbb12931c47fa44ec50eeffc projections.protostr
7334ba0a4544f0623231330fc51d390d shared_fc.protostr
-8b8b6bb128a7dfcc937be86145f53e2f shared_lstm.protostr
+bb8e233b05b8e07f9ed386b7aee4f2c6 shared_lstm.protostr
6b39e34beea8dfb782bee9bd3dea9eb5 simple_rnn_layers.protostr
+f98e79e1630d5eb827c300e64836d269 test_bi_grumemory.protostr
0fc1409600f1a3301da994ab9d28b0bf test_cost_layers.protostr
6cd5f28a3416344f20120698470e0a4c test_cost_layers_with_weight.protostr
144bc6d3a509de74115fa623741797ed test_expand_layer.protostr
@@ -15,7 +16,7 @@ d350bd91a0dc13e854b1364c3d9339c6 test_lstmemory_layer.protostr
5433ed33d4e7414eaf658f2a55946186 test_maxout.protostr
251a948ba41c1071afcd3d9cf9c233f7 test_ntm_layers.protostr
e6ff04e70aea27c7b06d808cc49c9497 test_print_layer.protostr
-2a75dd33b640c49a8821c2da6e574577 test_rnn_group.protostr
+fded24727338fb8ce44d9951ed8aea08 test_rnn_group.protostr
67d6fde3afb54f389d0ce4ff14726fe1 test_sequence_pooling.protostr
f586a548ef4350ba1ed47a81859a64cb unused_layers.protostr
-8122477f4f65244580cec09edc590041 util_layers.protostr
+f937a5a6e7e8864b4d8cf56b0f7c7f44 util_layers.protostr
diff --git a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
index 4b1d2d3d41d52..6a31ceabdf36d 100755
--- a/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
+++ b/python/paddle/trainer_config_helpers/tests/configs/generate_protostr.sh
@@ -9,7 +9,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight
-test_maxout)
+test_maxout test_bi_grumemory)
for conf in ${configs[*]}
diff --git a/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py b/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py
new file mode 100644
index 0000000000000..ab9f7c4948b85
--- /dev/null
+++ b/python/paddle/trainer_config_helpers/tests/configs/test_bi_grumemory.py
@@ -0,0 +1,10 @@
+from paddle.trainer_config_helpers import *
+
+settings(
+ batch_size=1000,
+ learning_rate=1e-4
+)
+
+din = data_layer(name='data', size=120)
+
+outputs(bidirectional_gru(input=din, size=40, return_seq=True))