From eab6b46ef52fdf3af915445e47aa0f3ff62acc68 Mon Sep 17 00:00:00 2001 From: Andrey Churkin Date: Thu, 23 Jan 2025 11:58:06 +0000 Subject: [PATCH] Update test_examples (#3195) ### Changes Add quantization_aware_training_tensorflow_mobilenet_v2 to test scope --- .../tensorflow/mobilenet_v2/main.py | 9 +++++--- tests/cross_fw/examples/.test_durations | 3 ++- tests/cross_fw/examples/example_scope.json | 21 +++++++++++++++++++ tests/cross_fw/examples/run_example.py | 9 ++++++++ 4 files changed, 38 insertions(+), 4 deletions(-) diff --git a/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py b/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py index bd0f6870b02..cf3bc372887 100644 --- a/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py +++ b/examples/quantization_aware_training/tensorflow/mobilenet_v2/main.py @@ -113,7 +113,7 @@ def preprocess_for_train(image, label): train_dataset = tfds.load("imagenette/320px-v2", split="train", shuffle_files=True, as_supervised=True) -train_dataset = train_dataset.map(preprocess_for_train).shuffle(1024).batch(128) +train_dataset = train_dataset.map(preprocess_for_train).batch(64) val_dataset = tfds.load("imagenette/320px-v2", split="validation", shuffle_files=False, as_supervised=True) val_dataset = val_dataset.map(preprocess_for_eval).batch(128) @@ -150,12 +150,15 @@ def transform_fn(data_item): tf_quantized_model = nncf.quantize(tf_model, calibration_dataset) tf_quantized_model.compile( - optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5), + optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5), loss=tf.keras.losses.CategoricalCrossentropy(), metrics=[tf.keras.metrics.CategoricalAccuracy()], ) -tf_quantized_model.fit(train_dataset, epochs=3, verbose=1) +# To minimize the example's runtime, we train for only 1 epoch. This is sufficient to demonstrate +# that the quantized model produced by QAT is more accurate than the one produced by PTQ. +# However, training for more than 1 epoch would further improve the quantized model's accuracy. +tf_quantized_model.fit(train_dataset, epochs=1, verbose=1) # Removes auxiliary layers and operations added during the quantization process, # resulting in a clean, fully quantized model ready for deployment. diff --git a/tests/cross_fw/examples/.test_durations b/tests/cross_fw/examples/.test_durations index 5bcce770b14..4eaae374e77 100644 --- a/tests/cross_fw/examples/.test_durations +++ b/tests/cross_fw/examples/.test_durations @@ -14,5 +14,6 @@ "tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_torch_anomalib]": 478.797, "tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_torch_resnet18]": 1251.144, "tests/cross_fw/examples/test_examples.py::test_examples[post_training_quantization_torch_fx_resnet18]": 412.243, - "tests/cross_fw/examples/test_examples.py::test_examples[fp8_llm_quantization]": 229.69 + "tests/cross_fw/examples/test_examples.py::test_examples[fp8_llm_quantization]": 229.69, + "tests.cross_fw.examples.test_examples.test_examples[quantization_aware_training_tensorflow_mobilenet_v2]": 1500.00 } diff --git a/tests/cross_fw/examples/example_scope.json b/tests/cross_fw/examples/example_scope.json index f3105f825ed..c406189ac01 100644 --- a/tests/cross_fw/examples/example_scope.json +++ b/tests/cross_fw/examples/example_scope.json @@ -273,5 +273,26 @@ "Tokyo." ] } + }, + "quantization_aware_training_tensorflow_mobilenet_v2": { + "backend": "tf", + "requirements": "examples/quantization_aware_training/tensorflow/mobilenet_v2/requirements.txt", + "cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz", + "accuracy_tolerance": 0.003, + "accuracy_metrics": { + "fp32_top1": 0.987770676612854, + "int8_top1": 0.9737579822540283, + "accuracy_drop": 0.014012694358825684 + }, + "performance_metrics": { + "fp32_fps": 1703.04, + "int8_fps": 5796.3, + "performance_speed_up": 3.403501972942503 + }, + "model_size_metrics": { + "fp32_model_size": 8.596238136291504, + "int8_model_size": 2.69466495513916, + "model_compression_rate": 3.1900953474371994 + } } } diff --git a/tests/cross_fw/examples/run_example.py b/tests/cross_fw/examples/run_example.py index 95b59aa3128..ecdc0137368 100644 --- a/tests/cross_fw/examples/run_example.py +++ b/tests/cross_fw/examples/run_example.py @@ -279,6 +279,15 @@ def quantization_aware_training_torch_anomalib(data: Union[str, None]): } +def quantization_aware_training_tensorflow_mobilenet_v2() -> Dict[str, float]: + import tensorflow_datasets as tfds + + tfds.display_progress_bar(enable=False) + + example_root = str(PROJECT_ROOT / "examples" / "quantization_aware_training" / "tensorflow" / "mobilenet_v2") + return post_training_quantization_mobilenet_v2(example_root) + + def main(argv): parser = ArgumentParser() parser.add_argument("--name", help="Example name", required=True)