diff --git a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp index 6adf6768d96c..804d2e019d94 100644 --- a/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp +++ b/integrations/tensorflow/iree_tf_compiler/TFL/Passes.cpp @@ -52,9 +52,13 @@ void buildTFLImportPassPipeline(OpPassManager &pm) { // Convert all TFL ops to TOSA ops //---------------------------------------------------------------------------- - mlir::tosa::TOSATFTFLLegalizationPipelineOptions tosaOptions; pm.addPass(createLowerGlobalTensorsPass()); + + mlir::tosa::TOSATFTFLLegalizationPipelineOptions tosaOptions; + // Temporary work-around for https://github.com/google/iree/issues/8974 + tosaOptions.dequantize_tfl_softmax = true; mlir::tosa::createTFTFLtoTOSALegalizationPipeline(pm, tosaOptions); + pm.nest().addPass(mlir::tosa::createStripQuantTypesPass()); pm.addPass(createCanonicalizerPass()); pm.addPass(createReconcileUnrealizedCastsPass()); diff --git a/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py b/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py index fa95b2322018..259aa1f386a5 100644 --- a/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py +++ b/integrations/tensorflow/test/python/iree_tfl_tests/mobilebert_tf2_quant_test.py @@ -35,11 +35,16 @@ def generate_inputs(self, input_details): def compare_results(self, iree_results, tflite_results, details): super(MobileBertTest, self).compare_results(iree_results, tflite_results, details) - # We have confirmed in large scale accuracy tests that differences this large is acceptable. + # We have confirmed in large scale accuracy tests that differences as large + # as 5.0 is acceptable. We later further relaxed from 5.0 to 7.0 in + # https://github.com/google/iree/pull/9337 when quantized Softmax got + # de-quantized, which should be numerically correct albeit not bit-exact. + # The actual observed max error was ~ 6.36. The value 7.0 is that rounded up + # to the next integer. self.assertTrue( - np.isclose(iree_results[0], tflite_results[0], atol=5.0).all()) + np.isclose(iree_results[0], tflite_results[0], atol=7.0).all()) self.assertTrue( - np.isclose(iree_results[1], tflite_results[1], atol=5.0).all()) + np.isclose(iree_results[1], tflite_results[1], atol=7.0).all()) def test_compile_tflite(self): self.compile_and_execute()