diff --git a/nntrainer/layers/loss/mse_loss_layer.cpp b/nntrainer/layers/loss/mse_loss_layer.cpp index 3aed8125e0..ed4390655d 100644 --- a/nntrainer/layers/loss/mse_loss_layer.cpp +++ b/nntrainer/layers/loss/mse_loss_layer.cpp @@ -51,17 +51,27 @@ void MSELossLayer::forwarding(RunLayerContext &context, bool training) { void MSELossLayer::calcDerivative(RunLayerContext &context) { Tensor empty_tensor; - Tensor &ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX); - const Tensor &y2_ = context.getIncomingDerivative(SINGLE_INOUT_IDX); - Tensor &y2 = empty_tensor; - if (ret_derivative.getDataType() == ml::train::TensorDim::DataType::FP32) - y2 = y2_; + Tensor &ret_derivative = + context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() == + ml::train::TensorDim::DataType::FP32 + ? context.getOutgoingDerivative(SINGLE_INOUT_IDX) + : empty_tensor; - if (y2.empty()) - y2 = y2_.clone(ret_derivative.getDataType()); + if (ret_derivative.empty()) + ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX) + .clone(ml::train::TensorDim::DataType::FP32); - Tensor &y = context.getInput(SINGLE_INOUT_IDX); + Tensor &y = context.getInput(SINGLE_INOUT_IDX).getDataType() == + ml::train::TensorDim::DataType::FP32 + ? context.getInput(SINGLE_INOUT_IDX) + : empty_tensor; + + if (y.empty()) + y = context.getInput(SINGLE_INOUT_IDX) + .clone(ml::train::TensorDim::DataType::FP32); + + const Tensor &y2 = context.getIncomingDerivative(SINGLE_INOUT_IDX); y.subtract(y2, ret_derivative); float divider = ((float)y.size()) / 2; @@ -70,7 +80,15 @@ void MSELossLayer::calcDerivative(RunLayerContext &context) { "[MSELossLayer::calcDerivative] Error when calculating loss"); } + // Loss scaling needs the full precision of ret_derivative. Therefore, + // ret_derivative should be FP32 when applying the scale, and afterwards it + // needs to be converted back to the original type for backpropagation. 
+ LossLayer::applyLossScale(context, ret_derivative); + + if (context.getOutgoingDerivative(SINGLE_INOUT_IDX).getDataType() != + ml::train::TensorDim::DataType::FP32) + context.getOutgoingDerivative(SINGLE_INOUT_IDX).copyData(ret_derivative); } } // namespace nntrainer