diff --git a/docs/tutorials/advanced_tutorials/custom-models.md b/docs/tutorials/advanced_tutorials/custom-models.md
index e9854e2b..3cc6746d 100644
--- a/docs/tutorials/advanced_tutorials/custom-models.md
+++ b/docs/tutorials/advanced_tutorials/custom-models.md
@@ -119,6 +119,7 @@ This model can then be trained with the standard Qadence helper functions.
 ```python exec="on" source="material-block" result="json" session="custom-model"
 from qadence import run
 from qadence.ml_tools import Trainer, TrainConfig
+Trainer.set_use_grad(True)
 
 criterion = torch.nn.MSELoss()
 optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
diff --git a/docs/tutorials/qml/ml_tools/callbacks.md b/docs/tutorials/qml/ml_tools/callbacks.md
index 59232415..383af0bf 100644
--- a/docs/tutorials/qml/ml_tools/callbacks.md
+++ b/docs/tutorials/qml/ml_tools/callbacks.md
@@ -22,7 +22,7 @@ Qadence ml_tools offers several built-in callbacks for common tasks like saving
 
 Prints metrics at specified intervals.
 
-```python
+```python exec="on" source="material-block" html="1"
 from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import PrintMetrics
 
@@ -38,7 +38,8 @@ config = TrainConfig(
 
 Writes metrics to a specified logging destination.
 
-```python
+```python exec="on" source="material-block" html="1"
+from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import WriteMetrics
 
 write_metrics_callback = WriteMetrics(on="train_epoch_end", called_every=50)
@@ -53,7 +54,8 @@ config = TrainConfig(
 
 Plots metrics based on user-defined plotting functions.
 
-```python
+```python exec="on" source="material-block" html="1"
+from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import PlotMetrics
 
 plot_metrics_callback = PlotMetrics(on="train_epoch_end", called_every=100)
@@ -68,7 +70,8 @@ config = TrainConfig(
 
 Logs hyperparameters to keep track of training settings.
 
-```python
+```python exec="on" source="material-block" html="1"
+from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import LogHyperparameters
 
 log_hyper_callback = LogHyperparameters(on="train_start", called_every=1)
@@ -83,7 +86,8 @@ config = TrainConfig(
 
 Saves model checkpoints at specified intervals.
 
-```python
+```python exec="on" source="material-block" html="1"
+from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import SaveCheckpoint
 
 save_checkpoint_callback = SaveCheckpoint(on="train_epoch_end", called_every=100)
@@ -98,7 +102,8 @@ config = TrainConfig(
 
 Saves the best model checkpoint based on a validation criterion.
 
-```python
+```python exec="on" source="material-block" html="1"
+from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import SaveBestCheckpoint
 
 save_best_checkpoint_callback = SaveBestCheckpoint(on="val_epoch_end", called_every=10)
@@ -113,7 +118,8 @@ config = TrainConfig(
 
 Loads a saved model checkpoint at the start of training.
 
-```python
+```python exec="on" source="material-block" html="1"
+from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import LoadCheckpoint
 
 load_checkpoint_callback = LoadCheckpoint(on="train_start")
@@ -128,7 +134,8 @@ config = TrainConfig(
 
 Logs the model structure and parameters.
-```python
+```python exec="on" source="material-block" html="1"
+from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import LogModelTracker
 
 log_model_callback = LogModelTracker(on="train_end")
@@ -152,7 +159,7 @@ There are two main ways to define a callback:
 #### Example 1: Providing a Callback Function Directly
 
-```python
+```python exec="on" source="material-block" html="1"
 from qadence.ml_tools.callbacks import Callback
 
 # Define a custom callback function
 def custom_callback_function(trainer, config, writer):
@@ -161,15 +168,14 @@ def custom_callback_function(trainer, config, writer):
 
 # Create the callback instance
 custom_callback = Callback(
-    on="on_train_end",
-    called_every=5,
+    on="train_end",
     callback=custom_callback_function
 )
 ```
 
 #### Example 2: Subclassing the Callback
 
-```python
+```python exec="on" source="material-block" html="1"
 from qadence.ml_tools.callbacks import Callback
 
 class CustomCallback(Callback):
@@ -177,7 +183,7 @@ class CustomCallback(Callback):
         print("Custom behavior in run_callback method.")
 
 # Create the subclassed callback instance
-custom_callback = CustomCallback(on="on_train_end", called_every=10)
+custom_callback = CustomCallback(on="train_batch_end", called_every=10)
 ```
@@ -185,15 +191,15 @@ custom_callback = CustomCallback(on="on_train_end", called_every=10)
 
 To use callbacks in `TrainConfig`, add them to the `callbacks` list when configuring the training process.
 
-```python
+```python exec="on" source="material-block" html="1"
 from qadence.ml_tools import TrainConfig
 from qadence.ml_tools.callbacks import SaveCheckpoint, PrintMetrics
 
 config = TrainConfig(
     max_iter=10000,
     callbacks=[
-        SaveCheckpoint(on="on_val_epoch_end", called_every=50),
-        PrintMetrics(on="on_train_epoch_end", called_every=100),
+        SaveCheckpoint(on="val_epoch_end", called_every=50),
+        PrintMetrics(on="train_epoch_end", called_every=100),
     ]
 )
 ```
@@ -217,9 +223,9 @@ These defaults handle common needs, but you can also add custom callbacks to any
 
 To create a custom `Trainer` that includes a `PrintMetrics` callback executed specifically at the end of each epoch, follow the steps below.
-```python
+```python exec="on" source="material-block" html="1"
 from qadence.ml_tools.trainer import Trainer
-from qadence.ml_tools.callback import PrintMetrics
+from qadence.ml_tools.callbacks import PrintMetrics
 
 class CustomTrainer(Trainer):
     def __init__(self, *args, **kwargs):
diff --git a/docs/tutorials/qml/ml_tools/data_and_config.md b/docs/tutorials/qml/ml_tools/data_and_config.md
index 9411e338..d9662b16 100644
--- a/docs/tutorials/qml/ml_tools/data_and_config.md
+++ b/docs/tutorials/qml/ml_tools/data_and_config.md
@@ -86,7 +86,7 @@ n_epochs = 100
 print_parameters = lambda opt_res: print(opt_res.model.parameters())
 condition_print = lambda opt_res: opt_res.loss < 1.0e-03
 modify_extra_opt_res = {"n_epochs": n_epochs}
-custom_callback = Callback( on="on_train_end", callback = print_parameters, callback_condition=condition_print, modify_optimize_result=modify_extra_opt_res, called_every=10,)
+custom_callback = Callback(on="train_end", callback = print_parameters, callback_condition=condition_print, modify_optimize_result=modify_extra_opt_res, called_every=10,)
 
 config = TrainConfig(
     folder="some_path/",
@@ -170,7 +170,7 @@ def callback_fn(trainer, config, writer):
     if trainer.opt_res.loss < 0.001:
         print("Custom Callback: Loss threshold reached!")
 
-custom_callback = Callback(on = "on_train_epoch_end", called_every = 10, callback_function = callback_fn )
+custom_callback = Callback(on = "train_epoch_end", called_every = 10, callback_function = callback_fn )
 
 config = TrainConfig(callbacks=[custom_callback])
 ```
diff --git a/qadence/ml_tools/callbacks/callback.py b/qadence/ml_tools/callbacks/callback.py
index f77c981d..3d586518 100644
--- a/qadence/ml_tools/callbacks/callback.py
+++ b/qadence/ml_tools/callbacks/callback.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable
 
 from qadence.ml_tools.callbacks.saveload import load_checkpoint, write_checkpoint
 from qadence.ml_tools.callbacks.writer_registry import BaseWriter
@@ -24,9 +24,9 @@ class Callback:
            "val_batch_start", "val_batch_end", "test_batch_start", "test_batch_end"]
         called_every (int): Frequency of callback calls in terms of iterations.
-        callback (Optional[CallbackFunction]): The function to call if the condition is met.
-        callback_condition (Optional[CallbackConditionFunction]): Condition to check before calling.
-        modify_optimize_result (Optional[Union[CallbackFunction, dict[str, Any]]]):
+        callback (CallbackFunction | None): The function to call if the condition is met.
+        callback_condition (CallbackConditionFunction | None): Condition to check before calling.
+        modify_optimize_result (CallbackFunction | dict[str, Any] | None):
             Function to modify `OptimizeResult`.
         A callback can be defined in two ways:
@@ -81,14 +81,14 @@
     def __init__(
         self,
         on: str | TrainingStage = "idle",
         called_every: int = 1,
-        callback: Union[CallbackFunction, None] = None,
-        callback_condition: Union[CallbackConditionFunction, None] = None,
-        modify_optimize_result: Optional[Union[CallbackFunction, dict[str, Any]]] = None,
+        callback: CallbackFunction | None = None,
+        callback_condition: CallbackConditionFunction | None = None,
+        modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
     ):
         if not isinstance(called_every, int):
             raise ValueError("called_every must be a positive integer or 0")
 
-        self.callback: Union[CallbackFunction, None] = callback
+        self.callback: CallbackFunction | None = callback
         self.on: str | TrainingStage = on
         self.called_every: int = called_every
         self.callback_condition = callback_condition or (lambda _: True)
diff --git a/qadence/ml_tools/callbacks/writer_registry.py b/qadence/ml_tools/callbacks/writer_registry.py
index 4ca5c6f3..92a0b12f 100644
--- a/qadence/ml_tools/callbacks/writer_registry.py
+++ b/qadence/ml_tools/callbacks/writer_registry.py
@@ -46,7 +46,7 @@ class BaseWriter(ABC):
     run: Run # [attr-defined]
 
     @abstractmethod
-    def open(self, config: TrainConfig, iteration: int = None) -> Any:
+    def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
         """
         Opens the writer and prepares it for logging.
 
@@ -262,7 +262,7 @@ def __init__(self) -> None:
         self.run: Run
         self.mlflow: ModuleType
 
-    def open(self, config: TrainConfig, iteration: int = None) -> ModuleType | None:
+    def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType | None:
         """
         Opens the MLflow writer and initializes an MLflow run.
 
diff --git a/qadence/ml_tools/loss/loss.py b/qadence/ml_tools/loss/loss.py
index 9892c72b..d1bff72c 100644
--- a/qadence/ml_tools/loss/loss.py
+++ b/qadence/ml_tools/loss/loss.py
@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Callable, Dict, Union
+from typing import Callable
 
 import torch
 import torch.nn as nn
@@ -8,7 +8,7 @@
 def mse_loss(
     model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
-) -> tuple[torch.Tensor, Dict[str, float]]:
+) -> tuple[torch.Tensor, dict[str, float]]:
     """Computes the Mean Squared Error (MSE) loss between model predictions and targets.
 
     Args:
            - targets (torch.Tensor): The ground truth labels.
 
     Returns:
-        Tuple[torch.Tensor, Dict[str, float]]:
+        Tuple[torch.Tensor, dict[str, float]]:
            - loss (torch.Tensor): The computed MSE loss value.
-           - metrics (Dict[str, float]): A dictionary with the MSE loss value.
+           - metrics (dict[str, float]): A dictionary with the MSE loss value.
     """
     criterion = nn.MSELoss()
     inputs, targets = batch
@@ -33,7 +33,7 @@
 def cross_entropy_loss(
     model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
-) -> tuple[torch.Tensor, Dict[str, float]]:
+) -> tuple[torch.Tensor, dict[str, float]]:
     """Computes the Cross Entropy loss between model predictions and targets.
 
     Args:
            - targets (torch.Tensor): The ground truth labels.
 
     Returns:
-        Tuple[torch.Tensor, Dict[str, float]]:
+        Tuple[torch.Tensor, dict[str, float]]:
            - loss (torch.Tensor): The computed Cross Entropy loss value.
-           - metrics (Dict[str, float]): A dictionary with the Cross Entropy loss value.
+           - metrics (dict[str, float]): A dictionary with the Cross Entropy loss value.
""" criterion = nn.CrossEntropyLoss() inputs, targets = batch @@ -56,12 +56,12 @@ def cross_entropy_loss( return loss, metrics -def get_loss_fn(loss_fn: Union[None, Callable, str]) -> Callable: +def get_loss_fn(loss_fn: str | Callable | None) -> Callable: """ Returns the appropriate loss function based on the input argument. Args: - loss_fn (Union[None, Callable, str]): The loss function to use. + loss_fn (str | Callable | None): The loss function to use. - If `loss_fn` is a callable, it will be returned directly. - If `loss_fn` is a string, it should be one of: - "mse": Returns the `mse_loss` function. diff --git a/qadence/ml_tools/printing.py b/qadence/ml_tools/printing.py deleted file mode 100644 index 89bfb1f1..00000000 --- a/qadence/ml_tools/printing.py +++ /dev/null @@ -1,154 +0,0 @@ -from __future__ import annotations - -from logging import getLogger -from typing import Any, Callable, Union - -from matplotlib.figure import Figure -from torch import Tensor -from torch.nn import Module -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter - -from qadence.ml_tools.data import DictDataLoader -from qadence.types import ExperimentTrackingTool - -logger = getLogger(__name__) - -PlottingFunction = Callable[[Module, int], tuple[str, Figure]] -InputData = Union[Tensor, dict[str, Tensor]] - - -def print_metrics(loss: float | None, metrics: dict, iteration: int) -> None: - msg = " ".join( - [f"Iteration {iteration: >7} | Loss: {loss:.7f} -"] - + [f"{k}: {v.item():.7f}" for k, v in metrics.items()] - ) - print(msg) - - -def write_tensorboard( - writer: SummaryWriter, loss: float = None, metrics: dict | None = None, iteration: int = 0 -) -> None: - metrics = metrics or dict() - if loss is not None: - writer.add_scalar("loss", loss, iteration) - for key, arg in metrics.items(): - writer.add_scalar(key, arg, iteration) - - -def log_hyperparams_tensorboard(writer: SummaryWriter, hyperparams: dict, metrics: dict) -> None: - writer.add_hparams(hyperparams, metrics) - - -def plot_tensorboard( - writer: SummaryWriter, - model: Module, - iteration: int, - plotting_functions: tuple[PlottingFunction], -) -> None: - for pf in plotting_functions: - descr, fig = pf(model, iteration) - writer.add_figure(descr, fig, global_step=iteration) - - -def log_model_tensorboard( - writer: SummaryWriter, - model: Module, - dataloader: Union[None, DataLoader, DictDataLoader], -) -> None: - logger.warning("Model logging is not supported by tensorboard. 
No model will be logged.") - - -def write_mlflow(writer: Any, loss: float | None, metrics: dict, iteration: int) -> None: - writer.log_metrics({"loss": float(loss)}, step=iteration) # type: ignore - writer.log_metrics(metrics, step=iteration) # logs the single metrics - - -def log_hyperparams_mlflow(writer: Any, hyperparams: dict, metrics: dict) -> None: - writer.log_params(hyperparams) # type: ignore - - -def plot_mlflow( - writer: Any, - model: Module, - iteration: int, - plotting_functions: tuple[PlottingFunction], -) -> None: - for pf in plotting_functions: - descr, fig = pf(model, iteration) - writer.log_figure(fig, descr) - - -def log_model_mlflow( - writer: Any, model: Module, dataloader: DataLoader | DictDataLoader | None -) -> None: - signature = None - if dataloader is not None: - xs: InputData - xs, *_ = next(iter(dataloader)) - preds = model(xs) - if isinstance(xs, Tensor): - xs = xs.numpy() - preds = preds.detach().numpy() - elif isinstance(xs, dict): - for key, val in xs.items(): - xs[key] = val.numpy() - for key, val in preds.items(): - preds[key] = val.detach.numpy() - - try: - from mlflow.models import infer_signature - - signature = infer_signature(xs, preds) - except ImportError: - logger.warning( - "An MLFlow specific function has been called but MLFlow failed to import." - "Please install MLFlow or adjust your code." - ) - - writer.pytorch.log_model(model, artifact_path="model", signature=signature) - - -TRACKER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = { - ExperimentTrackingTool.TENSORBOARD: write_tensorboard, - ExperimentTrackingTool.MLFLOW: write_mlflow, -} - -LOGGER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = { - ExperimentTrackingTool.TENSORBOARD: log_hyperparams_tensorboard, - ExperimentTrackingTool.MLFLOW: log_hyperparams_mlflow, -} - -PLOTTER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = { - ExperimentTrackingTool.TENSORBOARD: plot_tensorboard, - ExperimentTrackingTool.MLFLOW: plot_mlflow, -} - -MODEL_LOGGER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = { - ExperimentTrackingTool.TENSORBOARD: log_model_tensorboard, - ExperimentTrackingTool.MLFLOW: log_model_mlflow, -} - - -def write_tracker( - *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD -) -> None: - return TRACKER_MAPPING[tracking_tool](*args) - - -def log_tracker( - *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD -) -> None: - return LOGGER_MAPPING[tracking_tool](*args) - - -def plot_tracker( - *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD -) -> None: - return PLOTTER_MAPPING[tracking_tool](*args) - - -def log_model_tracker( - *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD -) -> None: - return MODEL_LOGGER_MAPPING[tracking_tool](*args) diff --git a/qadence/ml_tools/stages.py b/qadence/ml_tools/stages.py index 470e6fc0..aff30fad 100644 --- a/qadence/ml_tools/stages.py +++ b/qadence/ml_tools/stages.py @@ -1,4 +1,3 @@ - from __future__ import annotations from qadence.types import StrEnum diff --git a/qadence/ml_tools/train_utils/base_trainer.py b/qadence/ml_tools/train_utils/base_trainer.py index bd359217..152bd15e 100644 --- a/qadence/ml_tools/train_utils/base_trainer.py +++ b/qadence/ml_tools/train_utils/base_trainer.py @@ -2,7 +2,7 @@ from contextlib import contextmanager from logging import getLogger -from typing import Any, Callable, Iterator, Optional, Union +from typing import 
@@ -39,15 +39,16 @@ class BaseTrainer:
     Attributes:
         use_grad (bool): Indicates if gradients are used for optimization. Default is True.
-        _model (nn.Module): The neural network model.
-        _optimizer (Union[optim.Optimizer, NGOptimizer, None]]): The optimizer for training.
-        _config (TrainConfig): The configuration settings for training.
-        _train_dataloader (Optional[DataLoader]): DataLoader for training data.
-        _val_dataloader (Optional[DataLoader]): DataLoader for validation data.
-        _test_dataloader (Optional[DataLoader]): DataLoader for testing data.
+        model (nn.Module): The neural network model.
+        optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
+        config (TrainConfig): The configuration settings for training.
+        train_dataloader (DataLoader | None): DataLoader for training data.
+        val_dataloader (DataLoader | None): DataLoader for validation data.
+        test_dataloader (DataLoader | None): DataLoader for testing data.
         optimize_step (Callable): Function for performing an optimization step.
-        loss_fn (Union[Callable, str, None]): loss function to use.
+        loss_fn (Callable | str): loss function to use. Default loss function
+            used is 'mse'
         num_training_batches (int): Number of training batches.
            In case of InfiniteTensorDataset only 1 batch per epoch is used.
@@ -64,43 +65,44 @@
     def __init__(
         self,
         model: nn.Module,
-        optimizer: Union[optim.Optimizer, NGOptimizer, None],
+        optimizer: optim.Optimizer | NGOptimizer | None,
         config: TrainConfig,
-        loss_fn: Union[None, Callable, str],
+        loss_fn: str | Callable = "mse",
         optimize_step: Callable = optimize_step,
-        train_dataloader: DataLoader = None,
-        val_dataloader: DataLoader = None,
-        test_dataloader: DataLoader = None,
-        max_batches: int = None,
+        train_dataloader: DataLoader | None = None,
+        val_dataloader: DataLoader | None = None,
+        test_dataloader: DataLoader | None = None,
+        max_batches: int | None = None,
     ):
         """
         Initializes the BaseTrainer.
 
         Args:
            model (nn.Module): The model to train.
-           optimizer (Union[optim.Optimizer, NGOptimizer, None]): The optimizer
+           optimizer (optim.Optimizer | NGOptimizer | None): The optimizer
                for training.
            config (TrainConfig): The TrainConfig settings for training.
-           loss_fn (Union[None, Callable, str]): The loss function to use.
+           loss_fn (str | Callable): The loss function to use.
                str input to be specified to use a default loss function.
-               currently supported loss functions: 'mse', 'cross_entropy'
-           train_dataloader (Optional[DataLoader]): DataLoader for training data.
+               currently supported loss functions: 'mse', 'cross_entropy'.
+               If not specified, default mse loss will be used.
+           train_dataloader (DataLoader | None): DataLoader for training data.
                If the model does not need data to evaluate loss, no dataset
                should be provided.
-           val_dataloader (Optional[DataLoader]): DataLoader for validation data.
-           test_dataloader (Optional[DataLoader]): DataLoader for testing data.
-           max_batches (Optional[int]): Maximum number of batches to process per epoch.
+           val_dataloader (DataLoader | None): DataLoader for validation data.
+           test_dataloader (DataLoader | None): DataLoader for testing data.
+           max_batches (int | None): Maximum number of batches to process per epoch.
                This is only valid in case of finite TensorDataset dataloaders.
                if max_batches is not None, the maximum number of batches used will
                be min(max_batches, len(dataloader.dataset))
                In case of InfiniteTensorDataset only 1 batch per epoch is used.
""" self._model: nn.Module - self._optimizer: Union[optim.Optimizer, NGOptimizer, None] + self._optimizer: optim.Optimizer | NGOptimizer | None self._config: TrainConfig - self._train_dataloader: Optional[DataLoader] = None - self._val_dataloader: Optional[DataLoader] = None - self._test_dataloader: Optional[DataLoader] = None + self._train_dataloader: DataLoader | None = None + self._val_dataloader: DataLoader | None = None + self._test_dataloader: DataLoader | None = None self.config = config self.model = model @@ -138,33 +140,31 @@ def model(self, model: nn.Module) -> None: Sets the model, ensuring it is an instance of nn.Module. Args: - model (Optional[nn.Module]): The neural network model. + model (nn.Module): The neural network model. """ if model is not None and not isinstance(model, nn.Module): raise TypeError("model must be an instance of nn.Module or None.") self._model = model @property - def optimizer(self) -> Union[optim.Optimizer, NGOptimizer, None]: + def optimizer(self) -> optim.Optimizer | NGOptimizer | None: """ Returns the optimizer if set, otherwise raises an error. Returns: - Union[optim.Optimizer, NGOptimizer]: The optimizer. + optim.Optimizer | NGOptimizer | None: The optimizer. """ - if self._optimizer is None: - raise ValueError("Optimizer has not been set.") return self._optimizer @optimizer.setter - def optimizer(self, optimizer: Union[optim.Optimizer, NGOptimizer, None]) -> None: + def optimizer(self, optimizer: optim.Optimizer | NGOptimizer | None) -> None: """ Sets the optimizer, checking compatibility with gradient use. We also set up the budget/behavior of different optimizers here. Args: - optimizer (Union[optim.Optimizer, NGOptimizer]): The optimizer for training. + optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training. """ if optimizer is not None: if self.use_grad: @@ -263,7 +263,7 @@ def config(self, value: TrainConfig) -> None: Sets the training configuration and initializes callback and config managers. Args: - value (Optional[TrainConfig]): The configuration object. + value (TrainConfig): The configuration object. """ if value and not isinstance(value, TrainConfig): raise TypeError("config must be an instance of TrainConfig.") @@ -363,12 +363,12 @@ def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: return decorator @contextmanager - def enable_grad_opt(self, optimizer: optim.Optimizer = None) -> Iterator[None]: + def enable_grad_opt(self, optimizer: optim.Optimizer | None = None) -> Iterator[None]: """ Context manager to temporarily enable gradient-based optimization. Args: - optimizer Optional(optim.Optimizer): The PyTorch optimizer to use. + optimizer (optim.Optimizer): The PyTorch optimizer to use. If no optimizer is provided, default optimizer for trainer object will be used. """ @@ -385,12 +385,12 @@ def enable_grad_opt(self, optimizer: optim.Optimizer = None) -> Iterator[None]: self.optimizer = original_optimizer @contextmanager - def disable_grad_opt(self, optimizer: NGOptimizer = None) -> Iterator[None]: + def disable_grad_opt(self, optimizer: NGOptimizer | None = None) -> Iterator[None]: """ Context manager to temporarily disable gradient-based optimization. Args: - optimizer Optional(NGOptimizer): The Nevergrad optimizer to use. + optimizer (NGOptimizer): The Nevergrad optimizer to use. If no optimizer is provided, default optimizer for trainer object will be used. 
""" @@ -413,16 +413,18 @@ def on_train_start(self) -> None: def on_train_end( self, train_losses: list[list[tuple[torch.Tensor, Any]]], - val_losses: Optional[list[list[tuple[torch.Tensor, Any]]]] = None, + val_losses: list[list[tuple[torch.Tensor, Any]]] | None = None, ) -> None: """ Called at the end of training. Args: - train_losses: Metrics for the training losses. + train_losses (list[list[tuple[torch.Tensor, Any]]]): + Metrics for the training losses. list -> list -> tuples Epochs -> Training Batches -> (loss, metrics) - val_losses: Metrics for the validation losses. + val_losses (list[list[tuple[torch.Tensor, Any]]] | None): + Metrics for the validation losses. list -> list -> tuples Epochs -> Validation Batches -> (loss, metrics) """ diff --git a/qadence/ml_tools/train_utils/config_manager.py b/qadence/ml_tools/train_utils/config_manager.py index a601a29a..ff7cce00 100644 --- a/qadence/ml_tools/train_utils/config_manager.py +++ b/qadence/ml_tools/train_utils/config_manager.py @@ -4,7 +4,6 @@ import os from logging import getLogger from pathlib import Path -from typing import Union from torch import Tensor @@ -61,12 +60,12 @@ def _initialize_folder(self) -> None: if self.config.folder: self.config._log_folder = self._create_log_folder(self.config.folder) - def _create_log_folder(self, root_folder: Union[str, Path]) -> Path: + def _create_log_folder(self, root_folder: str | Path) -> Path: """ Create a log folder in the specified root folder, adding subfolders if required. Args: - root_folder (Union[str, Path]): The root folder where the log folder will be created. + root_folder (str | Path): The root folder where the log folder will be created. Returns: Path: The path to the created log folder. diff --git a/qadence/ml_tools/trainer.py b/qadence/ml_tools/trainer.py index 44406981..5b85ac4b 100644 --- a/qadence/ml_tools/trainer.py +++ b/qadence/ml_tools/trainer.py @@ -3,7 +3,7 @@ import copy from itertools import islice from logging import getLogger -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union, cast +from typing import Any, Callable, Iterable, cast import torch from nevergrad.optimization.base import Optimizer as NGOptimizer @@ -39,19 +39,19 @@ class Trainer(BaseTrainer): global_step (int): The global step across all epochs. log_device (str): Device for logging, default is "cpu". device (torch_device): Device used for computation. - dtype (torch_dtype): Data type used for computation. - data_dtype (Optional[torch_dtype]): Data type for data. + dtype (torch_dtype | None): Data type used for computation. + data_dtype (torch_dtype | None): Data type for data. Depends on the model's data type. Inherited Attributes: use_grad (bool): Indicates if gradients are used for optimization. Default is True. - model (Optional[nn.Module]): The neural network model. - optimizer (Optional[Union[optim.Optimizer, NGOptimizer]]): The optimizer for training. - config (Optional[TrainConfig]): The configuration settings for training. - train_dataloader (Optional[DataLoader]): DataLoader for training data. - val_dataloader (Optional[DataLoader]): DataLoader for validation data. - test_dataloader (Optional[DataLoader]): DataLoader for testing data. + model (nn.Module): The neural network model. + optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training. + config (TrainConfig): The configuration settings for training. + train_dataloader (DataLoader | None): DataLoader for training data. + val_dataloader (DataLoader | None): DataLoader for validation data. 
+        test_dataloader (DataLoader | None): DataLoader for testing data.
         optimize_step (Callable): Function for performing an optimization step.
         loss_fn (Callable): loss function to use.
@@ -232,32 +232,33 @@
     def __init__(
         self,
         model: nn.Module,
-        optimizer: Union[optim.Optimizer, NGOptimizer, None],
+        optimizer: optim.Optimizer | NGOptimizer | None,
         config: TrainConfig,
-        loss_fn: Union[None, Callable, str],
-        train_dataloader: DataLoader = None,
-        val_dataloader: DataLoader = None,
-        test_dataloader: DataLoader = None,
+        loss_fn: str | Callable = "mse",
+        train_dataloader: DataLoader | None = None,
+        val_dataloader: DataLoader | None = None,
+        test_dataloader: DataLoader | None = None,
         optimize_step: Callable = optimize_step,
-        device: torch_device = None,
-        dtype: torch_dtype = None,
-        max_batches: int = None,
+        device: torch_device | None = None,
+        dtype: torch_dtype | None = None,
+        max_batches: int | None = None,
     ):
         """
         Initializes the Trainer class.
 
         Args:
            model (nn.Module): The PyTorch model to train.
-           optimizer (optim.Optimizer): The optimizer for training.
+           optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
            config (TrainConfig): Training configuration object.
-           loss_fn (Union[None, Callable, str]): Loss function used for training.
-           train_dataloader (DataLoader): DataLoader for training data.
-           val_dataloader (DataLoader): DataLoader for validation data.
-           test_dataloader (DataLoader): DataLoader for test data.
+           loss_fn (str | Callable): Loss function used for training.
+               If not specified, default mse loss will be used.
+           train_dataloader (DataLoader | None): DataLoader for training data.
+           val_dataloader (DataLoader | None): DataLoader for validation data.
+           test_dataloader (DataLoader | None): DataLoader for test data.
            optimize_step (Callable): Function to execute an optimization step.
            device (torch_device): Device to use for computation.
            dtype (torch_dtype): Data type for computation.
-           max_batches (int): Maximum number of batches to process per epoch.
+           max_batches (int | None): Maximum number of batches to process per epoch.
               This is only valid in case of finite TensorDataset dataloaders.
               if max_batches is not None, the maximum number of batches used will
               be min(max_batches, len(dataloader.dataset))
@@ -277,15 +278,15 @@ def __init__(
         self.current_epoch: int = 0
         self.global_step: int = 0
         self.log_device: str = "cpu" if device is None else device
-        self.device: torch_device = device
-        self.dtype: torch_dtype = dtype
-        self.data_dtype: torch_dtype = None
+        self.device: torch_device | None = device
+        self.dtype: torch_dtype | None = dtype
+        self.data_dtype: torch_dtype | None = None
         if self.dtype:
             self.data_dtype = float64 if (self.dtype == complex128) else float32
 
     def fit(
-        self, train_dataloader: DataLoader = None, val_dataloader: DataLoader = None
-    ) -> Tuple[nn.Module, optim.Optimizer]:
+        self, train_dataloader: DataLoader | None = None, val_dataloader: DataLoader | None = None
+    ) -> tuple[nn.Module, optim.Optimizer]:
         """
         Fits the model using the specified training configuration.
 
        provided in the trainer will be used.
 
         Args:
-            train_dataloader Optional(DataLoader): DataLoader for training data.
-            val_dataloader Optional(DataLoader): DataLoader for validation data.
+            train_dataloader (DataLoader | None): DataLoader for training data.
+            val_dataloader (DataLoader | None): DataLoader for validation data.
 
         Returns:
-            Tuple[nn.Module, optim.Optimizer]: The trained model and optimizer.
+            tuple[nn.Module, optim.Optimizer]: The trained model and optimizer.
         """
         if train_dataloader is not None:
             self.train_dataloader = train_dataloader
@@ -350,14 +351,14 @@ def _fit_end(self) -> None:
         self.callback_manager.end_training(trainer=self)
 
     @BaseTrainer.callback("train")
-    def _train(self) -> List[List[Tuple[torch.Tensor, Dict[str, Any]]]]:
+    def _train(self) -> list[list[tuple[torch.Tensor, dict[str, Any]]]]:
         """
         Runs the main training loop, iterating over epochs.
 
         Returns:
-            List[List[Tuple[torch.Tensor, Dict[str, Any]]]]: Training loss
+            list[list[tuple[torch.Tensor, dict[str, Any]]]]: Training loss
                metrics for all epochs.
-               List -> List -> Tuples
+               list -> list -> tuples
                Epochs -> Training Batches -> (loss, metrics)
         """
         self.on_train_start()
@@ -400,7 +401,7 @@ def _train(self) -> List[List[Tuple[torch.Tensor, Dict[str, Any]]]]:
         return train_losses
 
     @BaseTrainer.callback("train_epoch")
-    def run_training(self, dataloader: DataLoader) -> List[Tuple[torch.Tensor, Dict[str, Any]]]:
+    def run_training(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
         """
         Runs the training for a single epoch, iterating over multiple batches.
 
         Args:
            dataloader (DataLoader): DataLoader for training data.
 
         Returns:
-            List[Tuple[torch.Tensor, Dict[str, Any]]]: Loss and metrics for each batch.
-               List -> Tuples
+            list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch.
+               list -> tuples
                Training Batches -> (loss, metrics)
         """
         self.model.train()
@@ -433,8 +434,8 @@ def run_training(self, dataloader: DataLoader) -> List[Tuple[torch.Tensor, Dict[
     @BaseTrainer.callback("train_batch")
     def run_train_batch(
-        self, batch: Tuple[torch.Tensor, ...]
-    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        self, batch: tuple[torch.Tensor, ...]
+    ) -> tuple[torch.Tensor, dict[str, Any]]:
         """
         Runs a single training batch, performing optimization.
 
        update_ng_parameters function.
 
         Args:
-            batch (Tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
+            batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
 
         Returns:
-            Tuple[torch.Tensor, Dict[str, Any]]: Loss and metrics for the batch.
-               Tuple of (loss, metrics)
+            tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch.
+               tuple of (loss, metrics)
         """
         if self.use_grad:
@@ -477,7 +478,7 @@ def run_train_batch(
         return self.modify_batch_end_loss_metrics(loss_metrics)
 
     @BaseTrainer.callback("val_epoch")
-    def run_validation(self, dataloader: DataLoader) -> List[Tuple[torch.Tensor, Dict[str, Any]]]:
+    def run_validation(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]:
         """
         Runs the validation loop for a single epoch, iterating over multiple batches.
 
         Args:
            dataloader (DataLoader): DataLoader for validation data.
 
         Returns:
-            List[Tuple[torch.Tensor, Dict[str, Any]]]: Loss and metrics for each batch.
+            list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch.
+               list -> tuples
                Validation Batches -> (loss, metrics)
         """
         self.model.eval()
@@ -501,23 +502,21 @@ def run_validation(self, dataloader: DataLoader) -> List[Tuple[torch.Tensor, Dic
         return val_epoch_loss_metrics
 
     @BaseTrainer.callback("val_batch")
-    def run_val_batch(self, batch: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, Dict[str, Any]]:
+    def run_val_batch(self, batch: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, dict[str, Any]]:
         """
         Runs a single validation batch.
 
         Args:
-            batch (Tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
+            batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
 
         Returns:
-            Tuple[torch.Tensor, Dict[str, Any]]: Loss and metrics for the batch.
+            tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch.
         """
         with torch.no_grad():
             loss_metrics = self.loss_fn(self.model, batch)
             return self.modify_batch_end_loss_metrics(loss_metrics)
 
-    def test(
-        self, test_dataloader: DataLoader = None
-    ) -> Optional[List[Tuple[torch.Tensor, Dict[str, Any]]]]:
+    def test(self, test_dataloader: DataLoader | None = None) -> list[tuple[torch.Tensor, dict[str, Any]]]:
         """
         Runs the testing loop if a test DataLoader is provided.
 
            test_dataloader (DataLoader): DataLoader for test data.
 
         Returns:
-            Optional[List[Tuple[torch.Tensor, Dict[str, Any]]]]: Loss and metrics for each batch.
-               List -> Tuples
+            list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch.
+               list -> tuples
                Test Batches -> (loss, metrics)
         """
         if test_dataloader is not None:
@@ -548,16 +547,16 @@
     @BaseTrainer.callback("test_batch")
     def run_test_batch(
-        self, batch: Tuple[torch.Tensor, ...]
-    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        self, batch: tuple[torch.Tensor, ...]
+    ) -> tuple[torch.Tensor, dict[str, Any]]:
         """
         Runs a single test batch.
 
         Args:
-            batch (Tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
+            batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader.
 
         Returns:
-            Tuple[torch.Tensor, Dict[str, Any]]: Loss and metrics for the batch.
+            tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch.
         """
         with torch.no_grad():
             loss_metrics = self.loss_fn(self.model, batch)
@@ -567,7 +566,7 @@
     def batch_iter(
         self,
         dataloader: DataLoader,
         num_batches: int,
-    ) -> Iterable[Union[Tuple[torch.Tensor, ...], None]]:
+    ) -> Iterable[tuple[torch.Tensor, ...] | None]:
         """
         Yields batches from the provided dataloader.
 
            num_batches (int): The maximum number of batches to yield.
 
         Yields:
-            Union[Tuple[torch.Tensor, ...], None]: A batch from the dataloader moved to the
+            Iterable[tuple[torch.Tensor, ...] | None]: A batch from the dataloader moved to the
                specified device and dtype.
         """
         if dataloader is None:
@@ -589,8 +588,8 @@ def batch_iter(
             yield batch
 
     def modify_batch_end_loss_metrics(
-        self, loss_metrics: Tuple[torch.Tensor, Dict[str, Any]]
-    ) -> Tuple[torch.Tensor, Dict[str, Any]]:
+        self, loss_metrics: tuple[torch.Tensor, dict[str, Any]]
+    ) -> tuple[torch.Tensor, dict[str, Any]]:
         """
         Modifies the loss and metrics at the end of batch for proper logging.
 
        A "{state}_loss" is added to metrics.
 
         Args:
-            loss_metrics (Tuple[torch.Tensor, Dict[str, Any]]): Original loss and metrics.
+            loss_metrics (tuple[torch.Tensor, dict[str, Any]]): Original loss and metrics.
 
         Returns:
-            Tuple[None | torch.Tensor, Dict[str, Any]]: Modified loss and metrics.
+            tuple[None | torch.Tensor, dict[str, Any]]: Modified loss and metrics.
""" for phase in ["train", "val", "test"]: if phase in self.training_stage: @@ -614,29 +613,28 @@ def modify_batch_end_loss_metrics( def build_optimize_result( self, - result: Union[ - None, - Tuple[torch.Tensor, Dict[Any, Any]], - List[Tuple[torch.Tensor, Dict[Any, Any]]], - List[List[Tuple[torch.Tensor, Dict[Any, Any]]]], - ], + result: None + | tuple[torch.Tensor, dict[Any, Any]] + | list[tuple[torch.Tensor, dict[Any, Any]]] + | list[list[tuple[torch.Tensor, dict[Any, Any]]]], ) -> None: """ Builds and stores the optimization result by calculating the average loss and metrics. Result (or loss_metrics) can have multiple formats: - `None` Indicates no loss or metrics data is provided. - - `Tuple[torch.Tensor, Dict[str, Any]]` A single tuple containing the loss tensor + - `tuple[torch.Tensor, dict[str, Any]]` A single tuple containing the loss tensor and metrics dictionary - at the end of batch. - - `List[Tuple[torch.Tensor, Dict[str, Any]]]` A list of tuples for + - `list[tuple[torch.Tensor, dict[str, Any]]]` A list of tuples for multiple batches. - - `List[List[Tuple[torch.Tensor, Dict[str, Any]]]]` A list of lists of tuples, + - `list[list[tuple[torch.Tensor, dict[str, Any]]]]` A list of lists of tuples, where each inner list represents metrics across multiple batches within an epoch. Args: - result: (Union[None, Tuple[torch.Tensor, Dict[str, Any]], - List[Tuple[torch.Tensor, Dict[str, Any]]], - List[List[Tuple[torch.Tensor, Dict[str, Any]]]]]) + result: (None | + tuple[torch.Tensor, dict[Any, Any]] | + list[tuple[torch.Tensor, dict[Any, Any]]] | + list[list[tuple[torch.Tensor, dict[Any, Any]]]]) The loss and metrics data, which can have multiple formats Returns: @@ -646,20 +644,20 @@ def build_optimize_result( loss_metrics = result if loss_metrics is None: loss = None - metrics: Dict[Any, Any] = {} + metrics: dict[Any, Any] = {} elif isinstance(loss_metrics, tuple): # Single tuple case loss, metrics = loss_metrics else: - last_epoch: List[Tuple[torch.Tensor, Dict[Any, Any]]] = [] + last_epoch: list[tuple[torch.Tensor, dict[Any, Any]]] = [] if isinstance(loss_metrics, list): # Check if it's a list of tuples if all(isinstance(item, tuple) for item in loss_metrics): - last_epoch = cast(List[Tuple[torch.Tensor, Dict[Any, Any]]], loss_metrics) + last_epoch = cast(list[tuple[torch.Tensor, dict[Any, Any]]], loss_metrics) # Check if it's a list of lists of tuples elif all(isinstance(item, list) for item in loss_metrics): last_epoch = cast( - List[Tuple[torch.Tensor, Dict[Any, Any]]], + list[tuple[torch.Tensor, dict[Any, Any]]], loss_metrics[-1] if loss_metrics else [], ) else: