From 6fd235a5b0eaa938e320071ee9a08c9e08298e07 Mon Sep 17 00:00:00 2001
From: mlahariya <40852060+mlahariya@users.noreply.github.com>
Date: Wed, 20 Nov 2024 13:00:23 +0100
Subject: [PATCH] Update minor changes

---
 tests/ml_tools/test_callbacks.py | 86 ++++++++++++++++----------------
 1 file changed, 43 insertions(+), 43 deletions(-)

diff --git a/tests/ml_tools/test_callbacks.py b/tests/ml_tools/test_callbacks.py
index fe5d13d88..4cb6911d4 100644
--- a/tests/ml_tools/test_callbacks.py
+++ b/tests/ml_tools/test_callbacks.py
@@ -1,29 +1,34 @@
+from __future__ import annotations
+
+from pathlib import Path
+from unittest.mock import Mock
+
 import pytest
 import torch
-from pathlib import Path
 from torch.utils.data import DataLoader
-from unittest.mock import Mock
+
+from qadence.ml_tools import TrainConfig, Trainer
 from qadence.ml_tools.callbacks import (
-    SaveCheckpoint,
-    SaveBestCheckpoint,
-    PrintMetrics,
-    WriteMetrics,
-    PlotMetrics,
-    LogHyperparameters,
     LoadCheckpoint,
+    LogHyperparameters,
     LogModelTracker,
+    PlotMetrics,
+    PrintMetrics,
+    SaveBestCheckpoint,
+    SaveCheckpoint,
+    WriteMetrics,
 )
-from qadence.ml_tools.data import to_dataloader
-from qadence.ml_tools import TrainConfig, Trainer
-from qadence.ml_tools.stages import TrainingStage
-from qadence.ml_tools.data import OptimizeResult
 from qadence.ml_tools.callbacks.saveload import write_checkpoint
+from qadence.ml_tools.data import OptimizeResult, to_dataloader
+from qadence.ml_tools.stages import TrainingStage
+
 
 def dataloader(batch_size: int = 25) -> DataLoader:
     x = torch.linspace(0, 1, batch_size).reshape(-1, 1)
     y = torch.cos(x)
     return to_dataloader(x, y, batch_size=batch_size, infinite=True)
 
+
 @pytest.fixture
 def trainer(Basic: torch.nn.Module, tmp_path: Path) -> Trainer:
     """Set up a real Trainer with a Basic and optimizer."""
@@ -34,15 +39,11 @@ def trainer(Basic: torch.nn.Module, tmp_path: Path) -> Trainer:
         log_folder=tmp_path,
         max_iter=1,
         checkpoint_best_only=True,
-        validation_criterion= lambda loss, best, ep : loss < (best - ep),
+        validation_criterion=lambda loss, best, ep: loss < (best - ep),
         val_epsilon=1e-5,
     )
     trainer = Trainer(
-        model=model,
-        optimizer=optimizer,
-        config=config,
-        loss_fn="mse",
-        train_dataloader=data
+        model=model, optimizer=optimizer, config=config, loss_fn="mse", train_dataloader=data
     )
     trainer.opt_result = OptimizeResult(
         iteration=1,
@@ -55,26 +56,32 @@ def trainer(Basic: torch.nn.Module, tmp_path: Path) -> Trainer:
     return trainer
 
 
-def test_save_checkpoint(trainer : Trainer):
+def test_save_checkpoint(trainer: Trainer) -> None:
+    writer = trainer.callback_manager.writer = Mock()
     stage = trainer.training_stage
     callback = SaveCheckpoint(stage, called_every=1)
-    callback(stage, trainer, trainer.config, None)
+    callback(stage, trainer, trainer.config, writer)
 
-    checkpoint_file = trainer.config.log_folder / f"model_{type(trainer.model).__name__}_ckpt_001_device_cpu.pt"
+    checkpoint_file = (
+        trainer.config.log_folder / f"model_{type(trainer.model).__name__}_ckpt_001_device_cpu.pt"
+    )
     assert checkpoint_file.exists()
 
 
-def test_save_best_checkpoint(trainer: Trainer):
+def test_save_best_checkpoint(trainer: Trainer) -> None:
+    writer = trainer.callback_manager.writer = Mock()
     stage = trainer.training_stage
     callback = SaveBestCheckpoint(on=stage, called_every=1)
-    callback(stage, trainer, trainer.config, None)
+    callback(stage, trainer, trainer.config, writer)
 
-    best_checkpoint_file = trainer.config.log_folder / f"model_{type(trainer.model).__name__}_ckpt_best_device_cpu.pt"
+    best_checkpoint_file = (
+        trainer.config.log_folder / f"model_{type(trainer.model).__name__}_ckpt_best_device_cpu.pt"
+    )
     assert best_checkpoint_file.exists()
-    assert callback.best_loss == trainer.opt_result.loss.item()
+    assert callback.best_loss == trainer.opt_result.loss
 
 
-def test_print_metrics(trainer: Trainer):
+def test_print_metrics(trainer: Trainer) -> None:
     writer = trainer.callback_manager.writer = Mock()
     stage = trainer.training_stage
     callback = PrintMetrics(on=stage, called_every=1)
@@ -82,7 +89,7 @@
     writer.print_metrics.assert_called_once_with(trainer.opt_result)
 
 
-def test_write_metrics(trainer: Trainer):
+def test_write_metrics(trainer: Trainer) -> None:
     writer = trainer.callback_manager.writer = Mock()
     stage = trainer.training_stage
     callback = WriteMetrics(on=stage, called_every=1)
@@ -90,15 +97,13 @@
     writer.write.assert_called_once_with(trainer.opt_result)
 
 
-def test_plot_metrics(trainer: Trainer):
-    trainer.config.plotting_functions = [
-        lambda model, iteration: ("plot_name", None)
-    ]
+def test_plot_metrics(trainer: Trainer) -> None:
+    trainer.config.plotting_functions = (lambda model, iteration: ("plot_name", None),)
     writer = trainer.callback_manager.writer = Mock()
     stage = trainer.training_stage
     callback = PlotMetrics(stage, called_every=1)
     callback(stage, trainer, trainer.config, writer)
-    
+
     writer.plot.assert_called_once_with(
         trainer.model,
         trainer.opt_result.iteration,
@@ -106,7 +111,7 @@
     )
 
 
-def test_log_hyperparameters(trainer: Trainer):
+def test_log_hyperparameters(trainer: Trainer) -> None:
     writer = trainer.callback_manager.writer = Mock()
     stage = trainer.training_stage
     trainer.config.hyperparams = {"learning_rate": 0.01, "epochs": 10}
@@ -115,25 +120,20 @@
     writer.log_hyperparams.assert_called_once_with(trainer.config.hyperparams)
 
 
-def test_load_checkpoint(trainer: Trainer):
+def test_load_checkpoint(trainer: Trainer) -> None:
     # Prepare a checkpoint
-    write_checkpoint(trainer.config.log_folder,
-                     trainer.model,
-                     trainer.optimizer,
-                     iteration=1)
-
+    write_checkpoint(trainer.config.log_folder, trainer.model, trainer.optimizer, iteration=1)
+
     writer = trainer.callback_manager.writer = Mock()
     stage = trainer.training_stage
     callback = LoadCheckpoint(stage, called_every=1)
-    model, optimizer, iteration = callback(
-        stage, trainer, trainer.config, None
-    )
+    model, optimizer, iteration = callback(stage, trainer, trainer.config, writer)
     assert model is not None
     assert optimizer is not None
     assert iteration == 1
 
 
-def test_log_model_tracker(trainer: Trainer):
+def test_log_model_tracker(trainer: Trainer) -> None:
     writer = trainer.callback_manager.writer = Mock()
     callback = LogModelTracker(on=trainer.training_stage, called_every=1)
     callback(trainer.training_stage, trainer, trainer.config, writer)