Skip to content

Commit

Permalink
Update test_examples (#3195)
Browse files Browse the repository at this point in the history
### Changes

Add quantization_aware_training_tensorflow_mobilenet_v2 to test scope
  • Loading branch information
andrey-churkin authored Jan 23, 2025
1 parent f574a1f commit eab6b46
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def preprocess_for_train(image, label):


train_dataset = tfds.load("imagenette/320px-v2", split="train", shuffle_files=True, as_supervised=True)
train_dataset = train_dataset.map(preprocess_for_train).shuffle(1024).batch(128)
train_dataset = train_dataset.map(preprocess_for_train).batch(64)

val_dataset = tfds.load("imagenette/320px-v2", split="validation", shuffle_files=False, as_supervised=True)
val_dataset = val_dataset.map(preprocess_for_eval).batch(128)
Expand Down Expand Up @@ -150,12 +150,15 @@ def transform_fn(data_item):
tf_quantized_model = nncf.quantize(tf_model, calibration_dataset)

tf_quantized_model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
optimizer=tf.keras.optimizers.RMSprop(learning_rate=1e-5),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=[tf.keras.metrics.CategoricalAccuracy()],
)

tf_quantized_model.fit(train_dataset, epochs=3, verbose=1)
# To minimize the example's runtime, we train for only 1 epoch. This is sufficient to demonstrate
# that the quantized model produced by QAT is more accurate than the one produced by PTQ.
# However, training for more than 1 epoch would further improve the quantized model's accuracy.
tf_quantized_model.fit(train_dataset, epochs=1, verbose=1)

# Removes auxiliary layers and operations added during the quantization process,
# resulting in a clean, fully quantized model ready for deployment.
Expand Down
3 changes: 2 additions & 1 deletion tests/cross_fw/examples/.test_durations
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,6 @@
"tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_torch_anomalib]": 478.797,
"tests/cross_fw/examples/test_examples.py::test_examples[quantization_aware_training_torch_resnet18]": 1251.144,
"tests/cross_fw/examples/test_examples.py::test_examples[post_training_quantization_torch_fx_resnet18]": 412.243,
"tests/cross_fw/examples/test_examples.py::test_examples[fp8_llm_quantization]": 229.69
"tests/cross_fw/examples/test_examples.py::test_examples[fp8_llm_quantization]": 229.69,
"tests.cross_fw.examples.test_examples.test_examples[quantization_aware_training_tensorflow_mobilenet_v2]": 1500.00
}
21 changes: 21 additions & 0 deletions tests/cross_fw/examples/example_scope.json
Original file line number Diff line number Diff line change
Expand Up @@ -273,5 +273,26 @@
"Tokyo."
]
}
},
"quantization_aware_training_tensorflow_mobilenet_v2": {
"backend": "tf",
"requirements": "examples/quantization_aware_training/tensorflow/mobilenet_v2/requirements.txt",
"cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
"accuracy_tolerance": 0.003,
"accuracy_metrics": {
"fp32_top1": 0.987770676612854,
"int8_top1": 0.9737579822540283,
"accuracy_drop": 0.014012694358825684
},
"performance_metrics": {
"fp32_fps": 1703.04,
"int8_fps": 5796.3,
"performance_speed_up": 3.403501972942503
},
"model_size_metrics": {
"fp32_model_size": 8.596238136291504,
"int8_model_size": 2.69466495513916,
"model_compression_rate": 3.1900953474371994
}
}
}
9 changes: 9 additions & 0 deletions tests/cross_fw/examples/run_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,15 @@ def quantization_aware_training_torch_anomalib(data: Union[str, None]):
}


def quantization_aware_training_tensorflow_mobilenet_v2() -> Dict[str, float]:
import tensorflow_datasets as tfds

tfds.display_progress_bar(enable=False)

example_root = str(PROJECT_ROOT / "examples" / "quantization_aware_training" / "tensorflow" / "mobilenet_v2")
return post_training_quantization_mobilenet_v2(example_root)


def main(argv):
parser = ArgumentParser()
parser.add_argument("--name", help="Example name", required=True)
Expand Down

0 comments on commit eab6b46

Please sign in to comment.