diff --git a/tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc b/tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc index f73ceed57c2..7c0cb69b5bf 100644 --- a/tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc +++ b/tensorflow/lite/micro/kernels/cmsis_nn/batch_matmul.cc @@ -121,7 +121,12 @@ inline TfLiteStatus PopulateEvalData( RuntimeShape tmp_r = SwapRowColumnDims(*rhs_shape); rhs_shape->ReplaceWith(tmp_r.DimensionsCount(), tmp_r.DimsData()); } - if (!params->adj_x) { + // ReferenceOps and CMSIS-NN have different requirements for when the + // lhs shape should be transposed, so we have to treat float differently. + if (!params->adj_x && original_lhs_input->type == kTfLiteFloat32) { + RuntimeShape tmp_l = SwapRowColumnDims(*lhs_shape); + lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData()); + } else if (params->adj_x && original_lhs_input->type != kTfLiteFloat32) { RuntimeShape tmp_l = SwapRowColumnDims(*lhs_shape); lhs_shape->ReplaceWith(tmp_l.DimensionsCount(), tmp_l.DimsData()); }