[ MSE ] Fix for better MSE loss precision
This PR changes the loss computation to use full precision rather than
half precision, to maintain accuracy.
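
For intuition, here is a standalone sketch (not nntrainer code) of the failure mode: `float` and `double` stand in for FP16 and FP32, and the term size and count are invented for illustration. A low-precision accumulator silently stops growing once each new squared-error term is smaller than half an ulp of the running sum:

```cpp
#include <cstdio>

int main() {
  // Accumulate many tiny squared-error terms, as an MSE reduction would.
  const long n = 1L << 26; // ~67M terms (illustrative)
  const float sq = 1e-8f;  // one squared-error term (illustrative)

  float sum_lo = 0.0f; // low-precision accumulator (stand-in for FP16)
  double sum_hi = 0.0; // full-precision accumulator (stand-in for FP32)
  for (long i = 0; i < n; ++i) {
    sum_lo += sq; // stalls near 0.25: sq < half an ulp of sum_lo
    sum_hi += sq;
  }
  printf("low  precision sum: %f\n", sum_lo); // ~0.25 (wrong)
  printf("full precision sum: %f\n", sum_hi); // ~0.67 (correct)
  return 0;
}
```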

**Changes proposed in this PR:**
- Compute the MSE derivative on FP32 tensors (cloning half-precision inputs as needed), apply the loss scale in FP32, then copy the result back to the original type.

Resolves:

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <[email protected]>
jijoongmoon committed May 17, 2024
1 parent 478e4c6 commit c2b28dd
Showing 1 changed file with 26 additions and 8 deletions.
nntrainer/layers/loss/mse_loss_layer.cpp (26 additions, 8 deletions)

```diff
@@ -51,17 +51,27 @@ void MSELossLayer::forwarding(RunLayerContext &context, bool training) {
 
 void MSELossLayer::calcDerivative(RunLayerContext &context) {
   Tensor empty_tensor;
-  Tensor &ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX);
-  const Tensor &y2_ = context.getIncomingDerivative(SINGLE_INOUT_IDX);
-  Tensor &y2 = empty_tensor;
 
-  if (ret_derivative.getDataType() == ml::train::TensorDim::DataType::FP32)
-    y2 = y2_;
+  Tensor &ret_derivative =
+    context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() ==
+        ml::train::TensorDim::DataType::FP32
+      ? context.getOutgoingDerivative(SINGLE_INOUT_IDX)
+      : empty_tensor;
 
-  if (y2.empty())
-    y2 = y2_.clone(ret_derivative.getDataType());
+  if (ret_derivative.empty())
+    ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX)
+                       .clone(ml::train::TensorDim::DataType::FP32);
 
-  Tensor &y = context.getInput(SINGLE_INOUT_IDX);
+  Tensor &y = context.getInput(SINGLE_INOUT_IDX).getDataType() ==
+                  ml::train::TensorDim::DataType::FP32
+                ? context.getInput(SINGLE_INOUT_IDX)
+                : empty_tensor;
+
+  if (y.empty())
+    y = context.getInput(SINGLE_INOUT_IDX)
+          .clone(ml::train::TensorDim::DataType::FP32);
+
+  const Tensor &y2 = context.getIncomingDerivative(SINGLE_INOUT_IDX);
 
   y.subtract(y2, ret_derivative);
   float divider = ((float)y.size()) / 2;
@@ -70,7 +80,15 @@ void MSELossLayer::calcDerivative(RunLayerContext &context) {
       "[MSELossLayer::calcDerivative] Error when calculating loss");
   }
 
+  // Loss scaling needs full precision of ret_derivative. Therefore,
+  // ret_derivative should be FP32 while the scale is applied, and afterwards
+  // it must be converted back to the original type for backpropagation.
+
   LossLayer::applyLossScale(context, ret_derivative);
+
+  if (context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() !=
+      ml::train::TensorDim::DataType::FP32)
+    context.getOutgoingDerivative(SINGLE_INOUT_IDX).copyData(ret_derivative);
 }
 
 } // namespace nntrainer
```
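
The new control flow is a promote-compute-demote pattern: use the outgoing derivative and input directly when they are already FP32, otherwise work on FP32 clones; apply the loss scale while still in full precision; and finally copy the FP32 result back into the half-precision buffer for backpropagation. Below is a minimal self-contained sketch of that pattern, with a hypothetical toy `Tensor` (plain `float` storage, modeling only the `empty`/`getDataType`/`clone`/`copyData` members seen in the diff) standing in for nntrainer's real tensor class:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for nntrainer's Tensor; storage is plain float and
// only the members used by the diff above are modeled.
enum class DataType { FP16, FP32 };

struct Tensor {
  DataType dtype = DataType::FP32;
  std::vector<float> data;

  bool empty() const { return data.empty(); }
  DataType getDataType() const { return dtype; }
  Tensor clone(DataType t) const { // copy the values under a new dtype
    Tensor c;
    c.dtype = t;
    c.data = data;
    return c;
  }
  void copyData(const Tensor &src) { // copy values, keep own dtype
    data = src.data;
  }
};

// Promote-compute-demote: do the arithmetic in FP32 regardless of the
// precision of the caller's output buffer.
void calcDerivativeSketch(Tensor &out_deriv, const Tensor &y,
                          const Tensor &label, float loss_scale) {
  Tensor empty_tensor;
  // Work in place when the output is already FP32, otherwise on a clone
  // (assigning through the reference fills empty_tensor, as in the diff).
  Tensor &work = out_deriv.getDataType() == DataType::FP32 ? out_deriv
                                                           : empty_tensor;
  if (work.empty())
    work = out_deriv.clone(DataType::FP32);

  // MSE derivative up to the 2/N factor: (y - label), with the mixed-
  // precision loss scale applied while still in full precision.
  work.data.resize(y.data.size());
  for (std::size_t i = 0; i < y.data.size(); ++i)
    work.data[i] = (y.data[i] - label.data[i]) * loss_scale;

  // Demote: copy the FP32 result back if the output is half precision.
  if (out_deriv.getDataType() != DataType::FP32)
    out_deriv.copyData(work);
}

int main() {
  Tensor y, label, deriv;
  y.dtype = label.dtype = deriv.dtype = DataType::FP16;
  y.data = {1.0f, 2.0f};
  label.data = {0.5f, 2.5f};

  calcDerivativeSketch(deriv, y, label, /*loss_scale=*/128.0f);
  printf("%f %f\n", deriv.data[0], deriv.data[1]); // 64.0 -64.0
  return 0;
}
```

In nntrainer itself, `copyData` performs the conversion into the destination tensor's own precision; the toy version just copies floats, but the control flow mirrors the diff.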
