diff --git a/Applications/Custom/mae_loss.cpp b/Applications/Custom/mae_loss.cpp
index 092f762cf6..6bc77796a8 100644
--- a/Applications/Custom/mae_loss.cpp
+++ b/Applications/Custom/mae_loss.cpp
@@ -24,7 +24,18 @@ static constexpr size_t SINGLE_INOUT_IDX = 0;
 
 void MaeLossLayer::forwarding(nntrainer::RunLayerContext &context,
                               bool training) {
-  nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX);
+  nntrainer::Tensor empty_tensor;
+
+  nntrainer::Tensor &predicted =
+    context.getInput(SINGLE_INOUT_IDX).getDataType() ==
+        ml::train::TensorDim::DataType::FP32
+      ? context.getInput(SINGLE_INOUT_IDX)
+      : empty_tensor;
+
+  if (predicted.empty())
+    predicted = context.getInput(SINGLE_INOUT_IDX)
+                  .clone(ml::train::TensorDim::DataType::FP32);
+
   nntrainer::Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
 
   if (!context.executeInPlace())
@@ -32,7 +43,18 @@ void MaeLossLayer::forwarding(nntrainer::RunLayerContext &context,
 }
 
 void MaeLossLayer::calcDerivative(nntrainer::RunLayerContext &context) {
-  nntrainer::Tensor &predicted = context.getInput(SINGLE_INOUT_IDX);
+  nntrainer::Tensor empty_tensor;
+
+  nntrainer::Tensor &predicted =
+    context.getInput(SINGLE_INOUT_IDX).getDataType() ==
+        ml::train::TensorDim::DataType::FP32
+      ? context.getInput(SINGLE_INOUT_IDX)
+      : empty_tensor;
+
+  if (predicted.empty())
+    predicted = context.getInput(SINGLE_INOUT_IDX)
+                  .clone(ml::train::TensorDim::DataType::FP32);
+
   nntrainer::Tensor &label = context.getLabel(SINGLE_INOUT_IDX);
 
   nntrainer::Tensor &deriv = context.getOutgoingDerivative(SINGLE_INOUT_IDX);