Skip to content

Commit

Permalink
preliminary implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
yangyaming committed May 15, 2017
1 parent 1ba8206 commit 6adf4ac
Showing 1 changed file with 132 additions and 10 deletions.
142 changes: 132 additions & 10 deletions paddle/gserver/layers/ConvShiftLayer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ class ConvShiftLayer : public Layer {

void forward(PassType passType) override;
void backward(const UpdateCallback& callback = nullptr) override;
bool isSeqType();
void circularConvSeq();
void circularConvSeqDerivative();
};

REGISTER_LAYER(conv_shift, ConvShiftLayer);
Expand All @@ -66,42 +69,161 @@ bool ConvShiftLayer::init(const LayerMap& layerMap,
return true;
}

bool ConvShiftLayer::isSeqType() {
const Argument& inLayer0 = getInput(0);
if (nullptr == inLayer0.sequenceStartPositions)
return false;
else
return true;
}

void ConvShiftLayer::circularConvSeq() {
const Argument& inLayer0 = getInput(0);
MatrixPtr in0 = inLayer0.value;
MatrixPtr in1 = getInputValue(1);
MatrixPtr out = getOutputValue();
const ICpuGpuVectorPtr& sequenceStartPositions =
inLayer0.sequenceStartPositions;

size_t width0 = in0->getWidth();
size_t numSeqs = sequenceStartPositions->getSize() - 1;
size_t height0 = in0->getHeight();
size_t width1 = in1->getWidth();
size_t height1 = in1->getHeight();

CHECK_EQ(numSeqs, height1);
CHECK_EQ(width0, out->getWidth());
CHECK_EQ(height0, out->getHeight());

CHECK_EQ(width1 % 2, 1U);

real* inV0 = in0->getData();
const int* startPosIntPtr = sequenceStartPositions->getData(false);
real* inV1 = in1->getData();
real* outV = out->getData();

int leftCtxLen = (width1 - 1) / 2;
for (size_t x = 0; x < numSeqs - 1; x++) {
int curSeqLen = startPosIntPtr[x + 1];
size_t curSeqWidth = curSeqLen * width0;
for (size_t i = 0; i < curSeqWidth; i++) {
for (size_t j = 0; j < width1; ++j) {
int index = i + j - leftCtxLen;
index = (index + curSeqWidth) % curSeqWidth;
int outVRowOffset = i / width0;
int outVColOffset = i % width0;
int inV0RowOffset = index / width0;
int inV0ColOffset = index % width0;
(outV + outVRowOffset)[outVColOffset] +=
(inV0 + inV0RowOffset)[inV0ColOffset] * inV1[j];
}
}
outV += curSeqWidth;
inV0 += curSeqWidth;
inV1 += width1;
}
}

void ConvShiftLayer::circularConvSeqDerivative() {
const Argument& inLayer0 = getInput(0);
MatrixPtr in0 = inLayer0.value;
MatrixPtr in1 = getInputValue(1);
MatrixPtr inG0 = getInputGrad(0);
MatrixPtr inG1 = getInputGrad(1);
MatrixPtr outG = getOutputGrad();
const ICpuGpuVectorPtr& sequenceStartPositions =
inLayer0.sequenceStartPositions;

size_t height0 = in0->getHeight();
size_t height1 = in1->getHeight();
size_t numSeqs = sequenceStartPositions->getSize() - 1;
size_t width0 = in0->getWidth();
size_t width1 = in1->getWidth();

CHECK_EQ(height1, numSeqs);
CHECK_EQ(height0, inG0->getHeight());
CHECK_EQ(width0, inG0->getWidth());
CHECK_EQ(height1, inG1->getHeight());
CHECK_EQ(width1, inG1->getWidth());
CHECK_EQ(height0, outG->getHeight());
CHECK_EQ(width0, outG->getWidth());

const int* startPosIntPtr = sequenceStartPositions->getData(false);
real* outGV = outG->getData();
real* inV0 = in0->getData();
real* inV1 = in1->getData();
real* inGV0 = inG0->getData();
real* inGV1 = inG1->getData();

int leftCtxLen = (width1 - 1) / 2;
for (size_t x = 0; x < numSeqs - 1; x++) {
int curSeqLen = startPosIntPtr[x + 1];
size_t curSeqWidth = curSeqLen * width0;
for (size_t j = 0; j < width1; j++) {
for (size_t i = 0; i < curSeqWidth; i++) {
int index = i + j - leftCtxLen;
index = (index + curSeqWidth) % curSeqWidth;
int inGV0RowOffset = index / width0;
int inGV0ColOffset = index % width0;
int outGVRowOffset = i / width0;
int outGVColOffset = i % width0;
(inGV0 + inGV0RowOffset)[inGV0ColOffset] +=
(outGV + outGVRowOffset)[outGVColOffset] * inV1[j];
inGV1[j] += (outGV + outGVRowOffset)[outGVColOffset] *
(inGV0 + inGV0RowOffset)[inGV0ColOffset];
}
}
outGV += curSeqWidth;
inV0 += curSeqWidth;
inV1 += width1;
inGV0 += curSeqWidth;
inGV1 += width1;
}
}

void ConvShiftLayer::forward(PassType passType) {
Layer::forward(passType);

MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);

size_t batchSize = inV0->getHeight();
size_t dataDim = inV0->getWidth();

CHECK_EQ(batchSize, inV1->getHeight());
CHECK_EQ(dataDim, getSize());

{
REGISTER_TIMER_INFO("FwResetTimer", getName().c_str());
resetOutput(batchSize, dataDim);
}

MatrixPtr outV = getOutputValue();

REGISTER_TIMER_INFO("FwConvShiftTimer", getName().c_str());
outV->circularConv(*inV0, *inV1);
if (!isSeqType()) {
MatrixPtr inV1 = getInputValue(1);
CHECK_EQ(batchSize, inV1->getHeight());
MatrixPtr outV = getOutputValue();
outV->circularConv(*inV0, *inV1);
} else {
circularConvSeq();
}
}

void ConvShiftLayer::backward(const UpdateCallback& callback) {
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
MatrixPtr outG = getOutputGrad();
MatrixPtr inG0 = getInputGrad(0);
MatrixPtr inG1 = getInputGrad(1);

REGISTER_TIMER_INFO("BwConvShiftTimer", getName().c_str());

if (inG0 && inG1) {
if (!(inG0 && inG1)) {
CHECK(!inG0 || !inG1) << "Not supported";
}

if (!isSeqType()) {
MatrixPtr inV0 = getInputValue(0);
MatrixPtr inV1 = getInputValue(1);
MatrixPtr outG = getOutputGrad();
outG->circularConvDerivative(*outG, *inV0, *inV1, *inG0, *inG1);
} else {
CHECK(!inG0 || !inG1) << "Not supported";
circularConvSeqDerivative();
}
}

Expand Down

0 comments on commit 6adf4ac

Please sign in to comment.