
Commit

Merge pull request #1149 from reyoung/feature/ErrorHandlingInPaddle
Feature/error handling in paddle
reyoung authored Jan 19, 2017
2 parents 3fff0af + 843fb2e commit 7f0ad62
Showing 12 changed files with 340 additions and 58 deletions.
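
This commit changes the activation and layer hooks from returning void (with CHECK-style aborts inside the library) to returning Error __must_check, so failures are reported back to the caller. The Error type itself lives in paddle/utils/Error.h and is not shown in this diff; the sketch below is a minimal stand-in reconstructed only from how the call sites use it (default construction means success, construction from a message means failure, !status tests for failure, status.check() aborts when a failure was recorded) and may differ from the real class.

// Minimal stand-in for paddle::Error, inferred from the call sites in this
// commit; not the actual implementation in paddle/utils/Error.h.
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>

// Assumed definition; paddle defines __must_check elsewhere, not in this diff.
#define __must_check __attribute__((warn_unused_result))

class Error {
public:
  Error() = default;                      // success
  explicit Error(const std::string& msg)  // failure carrying a message
      : msg_(std::make_shared<std::string>(msg)) {}

  // True when no error occurred, so callers can write `if (!status) ...`.
  explicit operator bool() const { return msg_ == nullptr; }

  // Counterpart of the status.check() calls added in Layer.cpp and
  // MDLstmLayer.cpp: terminate loudly if a failure was recorded.
  void check() const {
    if (msg_) {
      std::cerr << "Error: " << *msg_ << std::endl;
      std::abort();
    }
  }

private:
  std::shared_ptr<std::string> msg_;  // null means "no error"
};

// The two handling styles that appear in the diff, using hypothetical helper
// names (validateWidth/propagate are illustrations, not paddle functions).
Error __must_check validateWidth(std::size_t width) {
  if (width != 1) {
    return Error("Input width for each timestep of sequence softmax should be 1");
  }
  return Error();
}

Error __must_check propagate(std::size_t width) {
  Error status = validateWidth(width);
  if (!status) return status;  // hand the failure to our own caller
  return Error();
}

int main() {
  propagate(1).check();  // succeeds, check() is a no-op
  propagate(3).check();  // prints the message and aborts
  return 0;
}

In the diff itself, sequence_softmax propagates the Error returned by the inner softmax, while Layer::forwardActivation() and the MDLstmLayer call sites invoke check() on the returned status.
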
129 changes: 100 additions & 29 deletions paddle/gserver/activations/ActivationFunction.cpp
@@ -69,8 +69,14 @@ static ClassRegistrar<ActivationFunction> gActivationRegistrar;
class IdentityActivation : public ActivationFunction {
public:
static const std::string name;
void forward(Argument& act) { (void)act; }
void backward(Argument& act) { (void)act; }
Error __must_check forward(Argument& act) {
(void)act;
return Error();
}
Error __must_check backward(Argument& act) {
(void)act;
return Error();
}
const std::string& getName() const { return name; }
};
const std::string IdentityActivation::name = "";
@@ -86,8 +92,14 @@ static InitFunction __reg_activation__identity([] {
* \f]
*/
BEGIN_DEFINE_ACTIVATION(sigmoid)
void forward(Argument& act) { act.value->sigmoid(*act.value); }
void backward(Argument& act) { act.grad->sigmoidDerivative(*act.value); }
Error __must_check forward(Argument& act) {
act.value->sigmoid(*act.value);
return Error();
}
Error __must_check backward(Argument& act) {
act.grad->sigmoidDerivative(*act.value);
return Error();
}
END_DEFINE_ACTIVATION(sigmoid)

/**
@@ -103,9 +115,12 @@ MatrixPtr sftMaxDot_;
MatrixPtr one_;

public:
void forward(Argument& act) { act.value->softmax(*act.value); }
Error __must_check forward(Argument& act) {
act.value->softmax(*act.value);
return Error();
}

void backward(Argument& act) {
Error __must_check backward(Argument& act) {
MatrixPtr outputV = act.value;
MatrixPtr outputG = act.grad;

@@ -137,6 +152,7 @@ void backward(Argument& act) {

act.grad->softmaxDerivative(*act.value, *sftMaxSum_);
}
return Error();
}
END_DEFINE_ACTIVATION(softmax)

@@ -151,8 +167,11 @@ ACTIVATION_CLASS_NAME(softmax) softmax_;
Argument argument_;

public:
void forward(Argument& act) {
CHECK_EQ(act.value->getWidth(), 1UL);
Error __must_check forward(Argument& act) {
if (act.value->getWidth() != 1UL) {
return Error(
"Input width for each timestep of sequence softmax should be 1");
}

if (!argument_.value) {
argument_.value = Matrix::create(nullptr,
@@ -169,10 +188,14 @@ void forward(Argument& act) {

auto starts = act.sequenceStartPositions->getVector(useGpu(act.deviceId));
act.value->sequenceSoftmax(*act.value, *starts);
return Error();
}

void backward(Argument& act) {
CHECK_EQ(act.grad->getWidth(), 1UL);
Error __must_check backward(Argument& act) {
if (act.value->getWidth() != 1UL) {
return Error(
"Input width for each timestep of sequence softmax should be 1");
}

size_t numSequences = act.getNumSequences();
const int* starts = act.sequenceStartPositions->getData(false);
@@ -184,8 +207,10 @@ void backward(Argument& act) {
argument_.value->setData(act.value->getData() + offset, 1UL, size);
argument_.grad->setData(act.grad->getData() + offset, 1UL, size);

softmax_.backward(argument_);
Error status = softmax_.backward(argument_);
if (!status) return status;
}
return Error();
}
END_DEFINE_ACTIVATION(sequence_softmax)

@@ -200,9 +225,15 @@ END_DEFINE_ACTIVATION(sequence_softmax)
* 0 otherwise.
*/
BEGIN_DEFINE_ACTIVATION(relu)
void forward(Argument& act) { act.value->relu(*act.value); }
Error __must_check forward(Argument& act) {
act.value->relu(*act.value);
return Error();
}

void backward(Argument& act) { act.grad->reluDerivative(*act.value); }
Error __must_check backward(Argument& act) {
act.grad->reluDerivative(*act.value);
return Error();
}
END_DEFINE_ACTIVATION(relu)

/**
@@ -219,9 +250,15 @@ END_DEFINE_ACTIVATION(relu)
 * TODO(yuyang18): Remove magic number 24 or make it configurable.
*/
BEGIN_DEFINE_ACTIVATION(brelu)
void forward(Argument& act) { act.value->brelu(*act.value); }
Error __must_check forward(Argument& act) {
act.value->brelu(*act.value);
return Error();
}

void backward(Argument& act) { act.grad->breluDerivative(*act.value); }
Error __must_check backward(Argument& act) {
act.grad->breluDerivative(*act.value);
return Error();
}
END_DEFINE_ACTIVATION(brelu)

/**
@@ -231,9 +268,15 @@ END_DEFINE_ACTIVATION(brelu)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(tanh)
void forward(Argument& act) { act.value->tanh(*act.value); }
Error __must_check forward(Argument& act) {
act.value->tanh(*act.value);
return Error();
}

void backward(Argument& act) { act.grad->tanhDerivative(*act.value); }
Error __must_check backward(Argument& act) {
act.grad->tanhDerivative(*act.value);
return Error();
}
END_DEFINE_ACTIVATION(tanh)

/**
@@ -248,10 +291,14 @@ real a, b;

public:
ACTIVATION_CLASS_NAME(stanh)() : a(1.7159), b(2. / 3.) {}
void forward(Argument& act) { act.value->scaledTanh(*act.value, a, b); }
Error __must_check forward(Argument& act) {
act.value->scaledTanh(*act.value, a, b);
return Error();
}

void backward(Argument& act) {
Error __must_check backward(Argument& act) {
act.grad->scaledTanhDerivative(*act.value, a, b);
return Error();
}
END_DEFINE_ACTIVATION(stanh)

@@ -262,9 +309,15 @@ END_DEFINE_ACTIVATION(stanh)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(softrelu)
void forward(Argument& act) { act.value->softrelu(*act.value); }
Error __must_check forward(Argument& act) {
act.value->softrelu(*act.value);
return Error();
}

void backward(Argument& act) { act.grad->softreluDerivative(*act.value); }
Error __must_check backward(Argument& act) {
act.grad->softreluDerivative(*act.value);
return Error();
}
END_DEFINE_ACTIVATION(softrelu)

/**
@@ -280,7 +333,7 @@ END_DEFINE_ACTIVATION(softrelu)
* 0 if z=0
*/
BEGIN_DEFINE_ACTIVATION(abs)
void forward(Argument& act) {
Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in,
act.value->getHeight(),
@@ -290,9 +343,13 @@ void forward(Argument& act) {

act.in->copyFrom(*act.value);
act.value->abs2(*act.value);
return Error();
}

void backward(Argument& act) { act.grad->absDerivative(*act.in); }
Error __must_check backward(Argument& act) {
act.grad->absDerivative(*act.in);
return Error();
}
END_DEFINE_ACTIVATION(abs)

/**
@@ -302,7 +359,7 @@ END_DEFINE_ACTIVATION(abs)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(square)
void forward(Argument& act) {
Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in,
act.value->getHeight(),
@@ -312,9 +369,13 @@ void forward(Argument& act) {

act.in->copyFrom(*act.value);
act.value->square2(*act.value);
return Error();
}

void backward(Argument& act) { act.grad->squareDerivative(*act.in); }
Error __must_check backward(Argument& act) {
act.grad->squareDerivative(*act.in);
return Error();
}
END_DEFINE_ACTIVATION(square)

/**
Expand All @@ -324,9 +385,15 @@ END_DEFINE_ACTIVATION(square)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(exponential)
void forward(Argument& act) { act.value->exp2(*act.value); }
Error __must_check forward(Argument& act) {
act.value->exp2(*act.value);
return Error();
}

void backward(Argument& act) { act.grad->expDerivative(*act.value); }
Error __must_check backward(Argument& act) {
act.grad->expDerivative(*act.value);
return Error();
}
END_DEFINE_ACTIVATION(exponential)

/**
Expand All @@ -336,7 +403,7 @@ END_DEFINE_ACTIVATION(exponential)
* \f]
*/
BEGIN_DEFINE_ACTIVATION(log)
void forward(Argument& act) {
Error __must_check forward(Argument& act) {
SetDevice device(act.deviceId);
Matrix::resizeOrCreate(act.in,
act.value->getHeight(),
@@ -346,9 +413,13 @@ void forward(Argument& act) {

act.in->copyFrom(*act.value);
act.value->log2(*act.value);
return Error();
}

void backward(Argument& act) { act.grad->dotDiv(*act.grad, *act.in); }
Error __must_check backward(Argument& act) {
act.grad->dotDiv(*act.grad, *act.in);
return Error();
}
END_DEFINE_ACTIVATION(log)

ActivationFunction* ActivationFunction::create(const std::string& type) {
5 changes: 3 additions & 2 deletions paddle/gserver/activations/ActivationFunction.h
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/utils/Error.h"

namespace paddle {

@@ -48,7 +49,7 @@ class ActivationFunction {
*
* Usually, act is Layer::output_
*/
virtual void forward(Argument& act) = 0;
virtual Error __must_check forward(Argument& act) = 0;

/**
 * @brief Backward propagation
@@ -57,7 +58,7 @@ class ActivationFunction {
* - Before calling backward(), act.grad = dE / dy, where E is the error/cost
* - After backward() returns, act.grad = dE / dx = (dE/dy) * (dy/dx)
*/
virtual void backward(Argument& act) = 0;
virtual Error __must_check backward(Argument& act) = 0;

virtual const std::string& getName() const = 0;
};
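
The __must_check annotation on these virtuals is defined elsewhere in the tree, not in this diff; the assumption below is the usual GCC/Clang warn_unused_result attribute. Under that assumption, a caller that silently drops the returned Error gets a compiler diagnostic, which is what forces every call site in the .cpp changes to either propagate the status or call check(). A hypothetical standalone illustration:

// Demonstration only: Error is a stub and the macro definition is an
// assumption, not taken from paddle's headers.
#define __must_check __attribute__((warn_unused_result))

class Error {
public:
  Error() = default;
  explicit Error(const char* msg) : msg_(msg) {}
  explicit operator bool() const { return msg_ == nullptr; }

private:
  const char* msg_ = nullptr;  // null means "no error"
};

Error __must_check mightFail() { return Error("something went wrong"); }

int main() {
  mightFail();            // compiles, but warns: return value ignored
  Error e = mightFail();  // fine: the caller now owns the decision
  if (!e) return 1;       // handle or propagate the failure
  return 0;
}
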
7 changes: 5 additions & 2 deletions paddle/gserver/layers/Layer.cpp
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/utils/Util.h"

#include "paddle/math/SparseMatrix.h"
#include "paddle/utils/Error.h"
#include "paddle/utils/Logging.h"

#include "AddtoLayer.h"
@@ -334,7 +335,8 @@ void Layer::showOutputStats() {

void Layer::forwardActivation() {
/* activation */
activation_->forward(output_);
auto status = activation_->forward(output_);
status.check();

/* dropout */
if (config_.drop_rate() > 0) {
@@ -372,7 +374,8 @@ void Layer::backwardActivation() {
oGrad->dotMul(*oGrad, *dropOutMask_);
}

activation_->backward(output_);
auto status = activation_->backward(output_);
status.check();
}

void Layer::forwardDropOut() {
25 changes: 15 additions & 10 deletions paddle/gserver/layers/MDLstmLayer.cpp
@@ -506,9 +506,12 @@ void MDLstmLayer::forwardGate2OutputSequence(int start,
*frameState_[start + preOffsetV[i]].value, *checkFgOneDim, 1.0, 1.0);
}
}
activationGate_->forward(frameInputGate_[idxCurr]);
activationGate_->forward(frameForgetGate_[idxCurr]);
activation_->forward(frameInputNode_[idxCurr]);
auto status = activationGate_->forward(frameInputGate_[idxCurr]);
status.check();
status = activationGate_->forward(frameForgetGate_[idxCurr]);
status.check();
status = activation_->forward(frameInputNode_[idxCurr]);
status.check();

frameState_[idxCurr].value->zeroMem();
for (int i = 0; i < numDims_; i++) {
@@ -530,10 +533,12 @@

frameOutputGate_[idxCurr].value->addDotMul(
*frameState_[idxCurr].value, *checkOg_, 1.0, 1.0);
activationGate_->forward(frameOutputGate_[idxCurr]);
status = activationGate_->forward(frameOutputGate_[idxCurr]);
status.check();

framePreOutput_[idxCurr].value->copyFrom(*(frameState_[idxCurr].value));
activationState_->forward(framePreOutput_[idxCurr]);
status = activationState_->forward(framePreOutput_[idxCurr]);
status.check();

frameOutput_[idxCurr].value->dotMul(*framePreOutput_[idxCurr].value,
*frameOutputGate_[idxCurr].value);
@@ -640,12 +645,12 @@ void MDLstmLayer::backwardGate2OutputSequence(int start,

framePreOutput_[idxCurr].grad->dotMul(*frameOutput_[idxCurr].grad,
*frameOutputGate_[idxCurr].value);
activationState_->backward(framePreOutput_[idxCurr]);
activationState_->backward(framePreOutput_[idxCurr]).check();
frameState_[idxCurr].grad->copyFrom(*(framePreOutput_[idxCurr].grad));

frameOutputGate_[idxCurr].grad->dotMul(*frameOutput_[idxCurr].grad,
*framePreOutput_[idxCurr].value);
activationGate_->backward(frameOutputGate_[idxCurr]);
activationGate_->backward(frameOutputGate_[idxCurr]).check();

frameState_[idxCurr].grad->addDotMul(
*frameOutputGate_[idxCurr].grad, *checkOg_, 1.0, 1.0);
@@ -702,9 +707,9 @@ void MDLstmLayer::backwardGate2OutputSequence(int start,
}
}

activationGate_->backward(frameInputGate_[idxCurr]);
activationGate_->backward(frameForgetGate_[idxCurr]);
activation_->backward(frameInputNode_[idxCurr]);
activationGate_->backward(frameInputGate_[idxCurr]).check();
activationGate_->backward(frameForgetGate_[idxCurr]).check();
activation_->backward(frameInputNode_[idxCurr]).check();

if (bias_->getWGrad()) {
for (int i = 0; i < numDims_; i++) {