diff --git a/docs/api/ml_tools.md b/docs/api/ml_tools.md index f2f39e6a2..cbc05ea2d 100644 --- a/docs/api/ml_tools.md +++ b/docs/api/ml_tools.md @@ -1,6 +1,9 @@ ## ML Tools -This module implements gradient-free and gradient-based training loops for torch Modules and QuantumModel. It also implements the QNN class. +This module implements a `Trainer` class for torch `Modules` and `QuantumModel`. It also implements the `QNN` class and callbacks that can be used with the trainer module. + + +### ::: qadence.ml_tools.trainer ### ::: qadence.ml_tools.config @@ -8,10 +11,12 @@ This module implements gradient-free and gradient-based training loops for torch ### ::: qadence.ml_tools.optimize_step -### ::: qadence.ml_tools.train_grad - -### ::: qadence.ml_tools.train_no_grad - ### ::: qadence.ml_tools.data ### ::: qadence.ml_tools.models + +### ::: qadence.ml_tools.callbacks.callback + +### ::: qadence.ml_tools.train_utils.base_trainer + +### ::: qadence.ml_tools.callbacks.writer_registry diff --git a/docs/tutorials/advanced_tutorials/custom-models.md b/docs/tutorials/advanced_tutorials/custom-models.md index b0819a66b..3cc6746d5 100644 --- a/docs/tutorials/advanced_tutorials/custom-models.md +++ b/docs/tutorials/advanced_tutorials/custom-models.md @@ -118,7 +118,8 @@ This model can then be trained with the standard Qadence helper functions. ```python exec="on" source="material-block" result="json" session="custom-model" from qadence import run -from qadence.ml_tools import train_with_grad, TrainConfig +from qadence.ml_tools import Trainer, TrainConfig +Trainer.set_use_grad(True) criterion = torch.nn.MSELoss() optimizer = torch.optim.Adam(model.parameters(), lr=1e-1) @@ -128,9 +129,10 @@ def loss_fn(model: LearnHadamard, _unused) -> tuple[torch.Tensor, dict]: return loss, {} config = TrainConfig(max_iter=2500) -model, optimizer = train_with_grad( - model, None, optimizer, config, loss_fn=loss_fn +trainer = Trainer( + model, optimizer, config, loss_fn ) +model, optimizer = trainer.fit() wf_target = run(target_circuit) assert torch.allclose(wf_target, model.wavefunction(), atol=1e-2) diff --git a/docs/tutorials/digital_analog_qc/analog-qubo.md b/docs/tutorials/digital_analog_qc/analog-qubo.md index e4dd848d5..b7e38af48 100644 --- a/docs/tutorials/digital_analog_qc/analog-qubo.md +++ b/docs/tutorials/digital_analog_qc/analog-qubo.md @@ -56,7 +56,7 @@ ensure the reproducibility of this tutorial. import torch from qadence import QuantumModel, QuantumCircuit, Register from qadence import RydbergDevice, AnalogRX, AnalogRZ, chain -from qadence.ml_tools import train_gradient_free, TrainConfig, num_parameters +from qadence.ml_tools import Trainer, TrainConfig, num_parameters import nevergrad as ng import matplotlib.pyplot as plt @@ -80,12 +80,12 @@ Q = np.array( ] ) -def loss(model: QuantumModel, *args) -> tuple[float, dict]: +def loss(model: QuantumModel, *args) -> tuple[torch.Tensor, dict]: to_arr_fn = lambda bitstring: np.array(list(bitstring), dtype=int) cost_fn = lambda arr: arr.T @ Q @ arr samples = model.sample({}, n_shots=1000)[0] # extract samples cost_fn = sum(samples[key] * cost_fn(to_arr_fn(key)) for key in samples) - return cost_fn / sum(samples.values()), {} # We return an optional metrics dict + return torch.tensor(cost_fn / sum(samples.values())), {} # We return an optional metrics dict ``` The QAOA algorithm needs a variational quantum circuit with optimizable parameters. @@ -132,11 +132,14 @@ ML facilities to run gradient-free optimizations using the [`nevergrad`](https://facebookresearch.github.io/nevergrad/) library. ```python exec="on" source="material-block" session="qubo" +Trainer.set_use_grad(False) + config = TrainConfig(max_iter=100) optimizer = ng.optimizers.NGOpt( budget=config.max_iter, parametrization=num_parameters(model) ) -train_gradient_free(model, None, optimizer, config, loss) +trainer = Trainer(model, optimizer, config, loss) +trainer.fit() optimal_counts = model.sample({}, n_shots=1000)[0] print(f"optimal_count = {optimal_counts}") # markdown-exec: hide diff --git a/docs/tutorials/qml/dqc_1d.md b/docs/tutorials/qml/dqc_1d.md index 917df73d3..75f417c3f 100644 --- a/docs/tutorials/qml/dqc_1d.md +++ b/docs/tutorials/qml/dqc_1d.md @@ -112,7 +112,7 @@ print(html_string(circuit)) # markdown-exec: hide ## Training the model -Now that the model is defined we can proceed with the training. the `QNN` class can be used like any other `torch.nn.Module`. Here we write a simple training loop, but you can also look at the [ml tools tutorial](ml_tools.md) to use the convenience training functions that Qadence provides. +Now that the model is defined we can proceed with the training. the `QNN` class can be used like any other `torch.nn.Module`. Here we write a simple training loop, but you can also look at the [ml tools tutorial](ml_tools/trainer.md) to use the convenience training functions that Qadence provides. To train the model, we will select a random set of collocation points uniformly distributed within $-1.0< x <1.0$ and compute the loss function for those points. diff --git a/docs/tutorials/qml/index.md b/docs/tutorials/qml/index.md index 5bf073dd3..5b68da7c6 100644 --- a/docs/tutorials/qml/index.md +++ b/docs/tutorials/qml/index.md @@ -6,7 +6,7 @@ differentiation via integration with [PyTorch](https://pytorch.org/) deep learni Furthermore, Qadence offers a wide range of utilities for helping building and researching quantum machine learning algorithms, including: * [a set of constructors](../../content/qml_constructors.md) for circuits commonly used in quantum machine learning such as feature maps and ansatze -* [a set of tools](ml_tools.md) for training and optimizing quantum neural networks and loading classical data into a QML algorithm +* [a set of tools](ml_tools/trainer.md) for training and optimizing quantum neural networks and loading classical data into a QML algorithm ## Some simple examples diff --git a/docs/tutorials/qml/ml_tools.md b/docs/tutorials/qml/ml_tools.md deleted file mode 100644 index d9efcffe6..000000000 --- a/docs/tutorials/qml/ml_tools.md +++ /dev/null @@ -1,443 +0,0 @@ -## Dataloaders - -When using Qadence, you can supply classical data to a quantum machine learning -algorithm by using a standard PyTorch `DataLoader` instance. Qadence also provides -the `DictDataLoader` convenience class which allows -to build dictionaries of `DataLoader`s instances and easily iterate over them. - -```python exec="on" source="material-block" result="json" -import torch -from torch.utils.data import DataLoader, TensorDataset -from qadence.ml_tools import DictDataLoader, to_dataloader - - -def dataloader(data_size: int = 25, batch_size: int = 5, infinite: bool = False) -> DataLoader: - x = torch.linspace(0, 1, data_size).reshape(-1, 1) - y = torch.sin(x) - return to_dataloader(x, y, batch_size=batch_size, infinite=infinite) - - -def dictdataloader(data_size: int = 25, batch_size: int = 5) -> DictDataLoader: - dls = {} - for k in ["y1", "y2"]: - x = torch.rand(data_size, 1) - y = torch.sin(x) - dls[k] = to_dataloader(x, y, batch_size=batch_size, infinite=True) - return DictDataLoader(dls) - - -# iterate over standard DataLoader -for (x,y) in dataloader(data_size=6, batch_size=2): - print(f"Standard {x = }") - -# construct an infinite dataset which will keep sampling indefinitely -n_epochs = 5 -dl = iter(dataloader(data_size=6, batch_size=2, infinite=True)) -for _ in range(n_epochs): - (x, y) = next(dl) - print(f"Infinite {x = }") - -# iterate over DictDataLoader -ddl = dictdataloader() -data = next(iter(ddl)) -print(f"{data = }") -``` - -## Optimization routines - -For training QML models, Qadence also offers a few out-of-the-box routines for optimizing differentiable -models, _e.g._ `QNN`s and `QuantumModel`, containing either *trainable* and/or *non-trainable* parameters -(see [the parameters tutorial](../../content/parameters.md) for detailed information about parameter types): - -* [`train_with_grad`][qadence.ml_tools.train_with_grad] for gradient-based optimization using PyTorch native optimizers -* [`train_gradient_free`][qadence.ml_tools.train_gradient_free] for gradient-free optimization using -the [Nevergrad](https://facebookresearch.github.io/nevergrad/) library. - -These routines performs training, logging/printing loss metrics and storing intermediate checkpoints of models. In the following, we -use `train_with_grad` as example but the code can be used directly with the gradient-free routine. - -As every other training routine commonly used in Machine Learning, it requires -`model`, `data` and an `optimizer` as input arguments. -However, in addition, it requires a `loss_fn` and a `TrainConfig`. -A `loss_fn` is required to be a function which expects both a model and data and returns a tuple of (loss, metrics: ``, ...), where `metrics` is a dict of scalars which can be customized too. It can optionally also return additional values which are utilised by the corresponding user-provided `optimize_step` function inside `train_with_grad`. - -```python exec="on" source="material-block" -import torch -from itertools import count -cnt = count() -criterion = torch.nn.MSELoss() - -def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: - next(cnt) - x, y = data[0], data[1] - out = model(x) - loss = criterion(out, y) - return loss, {} - -``` - -The [`TrainConfig`][qadence.ml_tools.config.TrainConfig] tells `train_with_grad` what batch_size should be used, -how many epochs to train, in which intervals to print/log metrics and how often to store intermediate checkpoints. -It is also possible to provide custom callback functions by instantiating a [`Callback`][qadence.ml_tools.config.Callback] -with a function `callback` that only accept as argument an instance of [`OptimizeResult`][qadence.ml_tools.data.OptimizeResult] created within the `train` functions. -One can also provide a `callback_condition` function, also only accepting an instance of [`OptimizeResult`][qadence.ml_tools.data.OptimizeResult], which returns True if `callback` should be called. If no `callback_condition` is provided, `callback` is called at every x epochs (specified by `Callback`'s `called_every` argument). We can also specify in which part of the training function the `Callback` will be applied. Note that if you need it, you can modify the instance of [`OptimizeResult`][qadence.ml_tools.data.OptimizeResult] created in the train functions by specifying `Callback`'s `modify_optimize_result` (as a function or a dictionary). For instance, we could add inputs to the `extra` field of [`OptimizeResult`][qadence.ml_tools.data.OptimizeResult] to be used within `callback`. An example code is shown below. - -```python exec="on" source="material-block" -from qadence.ml_tools import OptimizeResult, TrainConfig, Callback - -batch_size = 5 -n_epochs = 100 - -print_parameters = lambda opt_res: print(opt_res.model.parameters()) -condition_print = lambda opt_res: opt_res.loss < 1.0e-03 -modify_extra_opt_res = {"n_epochs": n_epochs} -custom_callback = Callback(print_parameters, callback_condition=condition_print, modify_optimize_result=modify_extra_opt_res, called_every=10, call_end_epoch=True) - -config = TrainConfig( - folder="some_path/", - max_iter=n_epochs, - checkpoint_every=100, - write_every=100, - batch_size=batch_size, - callbacks = [custom_callback] -) -``` - -If it is desired to only the save the "best" checkpoint, the following must be ensured: - -(a) `checkpoint_best_only = True` is used while creating the configuration through `TrainConfig`, -(b) `val_every` is set to a valid integer value (for example, `val_every = 10`) which controls the no. of iterations after which the validation data should be used to evaluate the model during training, which can also be set through `TrainConfig`, -(c) a validation criterion is provided through the `validation_criterion`, set through `TrainConfig` to quantify the definition of "best", and -(d) the dataloader passed to `train_grad` is of type `DictDataLoader`. In this case, it is expected that a validation dataloader is also provided along with the train dataloader since the validation data will be used to decide the "best" checkpoint. The dataloaders must be accessible with specific keys: "train" and "val". - -The criterion used to decide the "best" checkpoint can be customized by `validation_criterion`, which should be a function that can take any number of arguments and return a boolean value (True or False) indicating whether some validation metric is satisfied or not. Typical choices are to return True when the validation loss (accuracy) has decreased (increased) compared to smallest (largest) value from previous iterations at which a validation check was performed. - -Let's see it in action with a simple example. - -### Fitting a funtion with a QNN using `ml_tools` - -In Quantum Machine Learning, the general consensus is to use `complex128` precision for states and operators and `float64` precision for parameters. This is also the convention which is used in `qadence`. -However, for specific usecases, lower precision can greatly speed up training and reduce memory consumption. When using the `pyqtorch` backend, `qadence` offers the option to move a `QuantumModel` instance to a specific precision using the torch `to` syntax. - -Let's look at a complete example of how to use `train_with_grad` now. Here we perform a validation check during training and use a validation criterion that checks whether the validation loss in the current iteration has decreased compared to the lowest validation loss from all previous iterations. For demonstration, the train and the validation data are kept the same here. However, it is beneficial and encouraged to keep them distinct in practice to understand model's generalization capabilities. - -```python exec="on" source="material-block" html="1" -from pathlib import Path -import torch -from functools import reduce -from operator import add -from itertools import count -import matplotlib.pyplot as plt - -from qadence import Parameter, QuantumCircuit, Z -from qadence import hamiltonian_factory, hea, feature_map, chain -from qadence import QNN -from qadence.ml_tools import TrainConfig, train_with_grad, to_dataloader, DictDataLoader - -DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') -DTYPE = torch.complex64 -n_qubits = 4 -fm = feature_map(n_qubits) -ansatz = hea(n_qubits=n_qubits, depth=3) -observable = hamiltonian_factory(n_qubits, detuning=Z) -circuit = QuantumCircuit(n_qubits, fm, ansatz) - -model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad") -batch_size = 100 -input_values = {"phi": torch.rand(batch_size, requires_grad=True)} -pred = model(input_values) - -cnt = count() -criterion = torch.nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr=0.1) - -def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: - next(cnt) - x, y = data[0], data[1] - out = model(x) - loss = criterion(out, y) - return loss, {} - -def validation_criterion( - current_validation_loss: float, current_best_validation_loss: float, val_epsilon: float -) -> bool: - return current_validation_loss <= current_best_validation_loss - val_epsilon - -n_epochs = 300 - -config = TrainConfig( - max_iter=n_epochs, - batch_size=batch_size, - checkpoint_best_only=True, - val_every=10, # The model will be run on the validation data after every `val_every` epochs. - validation_criterion=validation_criterion -) - -fn = lambda x, degree: .05 * reduce(add, (torch.cos(i*x) + torch.sin(i*x) for i in range(degree)), 0.) -x = torch.linspace(0, 10, batch_size, dtype=torch.float32).reshape(-1, 1) -y = fn(x, 5) - -data = DictDataLoader( - { - "train": to_dataloader(x, y, batch_size=batch_size, infinite=True), - "val": to_dataloader(x, y, batch_size=batch_size, infinite=True), - } -) - -train_with_grad(model, data, optimizer, config, loss_fn=loss_fn,device=DEVICE, dtype=DTYPE) - -plt.clf() -plt.plot(x.numpy(), y.numpy(), label='truth') -plt.plot(x.numpy(), model(x).detach().numpy(), "--", label="final", linewidth=3) -plt.legend() -from docs import docsutils # markdown-exec: hide -print(docsutils.fig_to_html(plt.gcf())) # markdown-exec: hide -``` - -For users who want to use the low-level API of `qadence`, here an example written without `train_with_grad`. - -### Fitting a function - Low-level API - -```python exec="on" source="material-block" -from pathlib import Path -import torch -from itertools import count -from qadence.constructors import hamiltonian_factory, hea, feature_map -from qadence import chain, Parameter, QuantumCircuit, Z -from qadence import QNN -from qadence.ml_tools import TrainConfig - -n_qubits = 2 -fm = feature_map(n_qubits) -ansatz = hea(n_qubits=n_qubits, depth=3) -observable = hamiltonian_factory(n_qubits, detuning=Z) -circuit = QuantumCircuit(n_qubits, fm, ansatz) - -model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad") -batch_size = 1 -input_values = {"phi": torch.rand(batch_size, requires_grad=True)} -pred = model(input_values) - -criterion = torch.nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr=0.1) -n_epochs=50 -cnt = count() - -tmp_path = Path("/tmp") - -config = TrainConfig( - folder=tmp_path, - max_iter=n_epochs, - checkpoint_every=100, - write_every=100, - batch_size=batch_size, -) - -x = torch.linspace(0, 1, batch_size).reshape(-1, 1) -y = torch.sin(x) - -for i in range(n_epochs): - out = model(x) - loss = criterion(out, y) - loss.backward() - optimizer.step() -``` - - -## Custom `train` loop - -If you need custom training functionality that goes beyond what is available in -`qadence.ml_tools.train_with_grad` and `qadence.ml_tools.train_gradient_free` you can write your own -training loop based on the building blocks that are available in Qadence. - -A simplified version of Qadence's train loop is defined below. Feel free to copy it and modify at -will. - -```python -from typing import Callable, Union - -from torch.nn import Module -from torch.optim import Optimizer -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter - -from qadence.ml_tools.config import TrainConfig -from qadence.ml_tools.data import DictDataLoader, data_to_device -from qadence.ml_tools.optimize_step import optimize_step -from qadence.ml_tools.printing import print_metrics, write_tensorboard -from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint - - -def train( - model: Module, - data: DataLoader, - optimizer: Optimizer, - config: TrainConfig, - loss_fn: Callable, - device: str = "cpu", - optimize_step: Callable = optimize_step, - write_tensorboard: Callable = write_tensorboard, -) -> tuple[Module, Optimizer]: - - # Move model to device before optimizer is loaded - model = model.to(device) - - # load available checkpoint - init_iter = 0 - if config.folder: - model, optimizer, init_iter = load_checkpoint(config.folder, model, optimizer) - - # initialize tensorboard - writer = SummaryWriter(config.folder, purge_step=init_iter) - - dl_iter = iter(dataloader) - - # outer epoch loop - for iteration in range(init_iter, init_iter + config.max_iter): - data = data_to_device(next(dl_iter), device) - loss, metrics = optimize_step(model, optimizer, loss_fn, data) - - if iteration % config.print_every == 0 and config.verbose: - print_metrics(loss, metrics, iteration) - - if iteration % config.write_every == 0: - write_tensorboard(writer, loss, metrics, iteration) - - if config.folder: - if iteration % config.checkpoint_every == 0: - write_checkpoint(config.folder, model, optimizer, iteration) - - # Final writing and checkpointing - if config.folder: - write_checkpoint(config.folder, model, optimizer, iteration) - write_tensorboard(writer, loss, metrics, iteration) - writer.close() - - return model, optimizer -``` - -## Experiment tracking with mlflow - -Qadence allows to track runs and log hyperparameters, models and plots with [tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) and [mlflow](https://mlflow.org/). In the following, we demonstrate the integration with mlflow. - -### mlflow configuration -We have control over our tracking configuration by setting environment variables. First, let's look at the tracking URI. For the purpose of this demo we will be working with a local database, in a similar fashion as described [here](https://mlflow.org/docs/latest/tracking/tutorials/local-database.html), -```bash -export MLFLOW_TRACKING_URI=sqlite:///mlruns.db -``` - -Qadence can also read the following two environment variables to define the mlflow experiment name and run name -```bash -export MLFLOW_EXPERIMENT=test_experiment -export MLFLOW_RUN_NAME=run_0 -``` - -If no tracking URI is provided, mlflow stores run information and artifacts in the local `./mlflow` directory and if no names are defined, the experiment and run will be named with random UUIDs. - -### Setup -Let's do the necessary imports and declare a `DataLoader`. We can already define some hyperparameters here, including the seed for random number generators. mlflow can log hyperparameters with arbitrary types, for example the observable that we want to monitor (`Z` in this case, which has a `qadence.Operation` type). - -```python -import random -from itertools import count - -import numpy as np -import torch -from matplotlib import pyplot as plt -from matplotlib.figure import Figure -from torch.nn import Module -from torch.utils.data import DataLoader - -from qadence import hea, QuantumCircuit, Z -from qadence.constructors import feature_map, hamiltonian_factory -from qadence.ml_tools import train_with_grad, TrainConfig -from qadence.ml_tools.data import to_dataloader -from qadence.ml_tools.utils import rand_featureparameters -from qadence.models import QNN, QuantumModel -from qadence.types import ExperimentTrackingTool - -hyperparams = { - "seed": 42, - "batch_size": 10, - "n_qubits": 2, - "ansatz_depth": 1, - "observable": Z, -} - -np.random.seed(hyperparams["seed"]) -torch.manual_seed(hyperparams["seed"]) -random.seed(hyperparams["seed"]) - - -def dataloader(batch_size: int = 25) -> DataLoader: - x = torch.linspace(0, 1, batch_size).reshape(-1, 1) - y = torch.cos(x) - return to_dataloader(x, y, batch_size=batch_size, infinite=True) -``` - -We continue with the regular QNN definition, together with the loss function and optimizer. - -```python -obs = hamiltonian_factory(register=hyperparams["n_qubits"], detuning=hyperparams["observable"]) - -data = dataloader(hyperparams["batch_size"]) -fm = feature_map(hyperparams["n_qubits"], param="x") - -model = QNN( - QuantumCircuit( - hyperparams["n_qubits"], fm, hea(hyperparams["n_qubits"], hyperparams["ansatz_depth"]) - ), - observable=obs, - inputs=["x"], -) - -cnt = count() -criterion = torch.nn.MSELoss() -optimizer = torch.optim.Adam(model.parameters(), lr=0.1) - -inputs = rand_featureparameters(model, 1) - -def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict]: - next(cnt) - out = model.expectation(inputs) - loss = criterion(out, torch.rand(1)) - return loss, {} -``` - -### `TrainConfig` specifications -Qadence offers different tracking options via `TrainConfig`. Here we use the `ExperimentTrackingTool` type to specify that we want to track the experiment with mlflow. Tracking with tensorboard is also possible. We can then indicate *what* and *how often* we want to track or log. `write_every` controls the number of epochs after which the loss values is logged. Thanks to the `plotting_functions` and `plot_every`arguments, we are also able to plot model-related quantities throughout training. Notice that arbitrary plotting functions can be passed, as long as the signature is the same as `plot_fn` below. Finally, the trained model can be logged by setting `log_model=True`. Here is an example of plotting function and training configuration - -```python -def plot_fn(model: Module, iteration: int) -> tuple[str, Figure]: - descr = f"ufa_prediction_epoch_{iteration}.png" - fig, ax = plt.subplots() - x = torch.linspace(0, 1, 100).reshape(-1, 1) - out = model.expectation(x) - ax.plot(x.detach().numpy(), out.detach().numpy()) - return descr, fig - - -config = TrainConfig( - folder="mlflow_demonstration", - max_iter=10, - checkpoint_every=1, - plot_every=2, - write_every=1, - log_model=True, - tracking_tool=ExperimentTrackingTool.MLFLOW, - hyperparams=hyperparams, - plotting_functions=(plot_fn,), -) -``` - -### Training and inspecting -Model training happens as usual -```python -train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) -``` - -After training , we can inspect our experiment via the mlflow UI -```bash -mlflow ui --port 8080 --backend-store-uri sqlite:///mlruns.db -``` -In this case, since we're running on a local server, we can access the mlflow UI by navigating to http://localhost:8080/. diff --git a/docs/tutorials/qml/ml_tools/callbacks.md b/docs/tutorials/qml/ml_tools/callbacks.md new file mode 100644 index 000000000..383af0bf6 --- /dev/null +++ b/docs/tutorials/qml/ml_tools/callbacks.md @@ -0,0 +1,237 @@ + +# Callbacks for Trainer + +Qadence `ml_tools` provides a powerful callback system for customizing various stages of the training process. With callbacks, you can monitor, log, save, and alter your training workflow efficiently. A `CallbackManager` is used with [`Trainer`][qadence.ml_tools.Trainer] to execute the training process with defined callbacks. Following default callbacks are already provided in the [`Trainer`][qadence.ml_tools.Trainer]. + +### Default Callbacks +Below is a list of the default callbacks already implemented in the `CallbackManager` used with [`Trainer`][qadence.ml_tools.Trainer]: + +- **`train_start`**: `PlotMetrics`, `SaveCheckpoint`, `WriteMetrics` +- **`train_epoch_end`**: `SaveCheckpoint`, `PrintMetrics`, `PlotMetrics`, `WriteMetrics` +- **`val_epoch_end`**: `SaveBestCheckpoint`, `WriteMetrics` +- **`train_end`**: `LogHyperparameters`, `LogModelTracker`, `WriteMetrics`, `SaveCheckpoint`, `PlotMetrics` + +This guide covers how to define and use callbacks in `TrainConfig`, integrate them with the `Trainer` class, and create custom callbacks using hooks. + + +## 1. Built-in Callbacks + +Qadence ml_tools offers several built-in callbacks for common tasks like saving checkpoints, logging metrics, and tracking models. Below is an overview of each. + +### 1.1. `PrintMetrics` + +Prints metrics at specified intervals. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import PrintMetrics + +print_metrics_callback = PrintMetrics(on="val_batch_end", called_every=100) + +config = TrainConfig( + max_iter=10000, + callbacks=[print_metrics_callback] +) +``` + +### 1.2. `WriteMetrics` + +Writes metrics to a specified logging destination. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import WriteMetrics + +write_metrics_callback = WriteMetrics(on="train_epoch_end", called_every=50) + +config = TrainConfig( + max_iter=5000, + callbacks=[write_metrics_callback] +) +``` + +### 1.3. `PlotMetrics` + +Plots metrics based on user-defined plotting functions. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import PlotMetrics + +plot_metrics_callback = PlotMetrics(on="train_epoch_end", called_every=100) + +config = TrainConfig( + max_iter=5000, + callbacks=[plot_metrics_callback] +) +``` + +### 1.4. `LogHyperparameters` + +Logs hyperparameters to keep track of training settings. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import LogHyperparameters + +log_hyper_callback = LogHyperparameters(on="train_start", called_every=1) + +config = TrainConfig( + max_iter=1000, + callbacks=[log_hyper_callback] +) +``` + +### 1.5. `SaveCheckpoint` + +Saves model checkpoints at specified intervals. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import SaveCheckpoint + +save_checkpoint_callback = SaveCheckpoint(on="train_epoch_end", called_every=100) + +config = TrainConfig( + max_iter=10000, + callbacks=[save_checkpoint_callback] +) +``` + +### 1.6. `SaveBestCheckpoint` + +Saves the best model checkpoint based on a validation criterion. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import SaveBestCheckpoint + +save_best_checkpoint_callback = SaveBestCheckpoint(on="val_epoch_end", called_every=10) + +config = TrainConfig( + max_iter=10000, + callbacks=[save_best_checkpoint_callback] +) +``` + +### 1.7. `LoadCheckpoint` + +Loads a saved model checkpoint at the start of training. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import LoadCheckpoint + +load_checkpoint_callback = LoadCheckpoint(on="train_start") + +config = TrainConfig( + max_iter=10000, + callbacks=[load_checkpoint_callback] +) +``` + +### 1.8. `LogModelTracker` + +Logs the model structure and parameters. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import LogModelTracker + +log_model_callback = LogModelTracker(on="train_end") + +config = TrainConfig( + max_iter=1000, + callbacks=[log_model_callback] +) +``` + + +## 2. Custom Callbacks + +The base `Callback` class in Qadence allows defining custom behavior that can be triggered at specified events (e.g., start of training, end of epoch). You can set parameters such as when the callback runs (`on`), frequency of execution (`called_every`), and optionally define a `callback_condition`. + +### Defining Callbacks + +There are two main ways to define a callback: +1. **Directly providing a function** in the `Callback` instance. +2. **Subclassing** the `Callback` class and implementing custom logic. + +#### Example 1: Providing a Callback Function Directly + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools.callbacks import Callback + +# Define a custom callback function +def custom_callback_function(trainer, config, writer): + print("Executing custom callback.") + +# Create the callback instance +custom_callback = Callback( + on="train_end", + callback=custom_callback_function +) +``` + +#### Example 2: Subclassing the Callback + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools.callbacks import Callback + +class CustomCallback(Callback): + def run_callback(self, trainer, config, writer): + print("Custom behavior in run_callback method.") + +# Create the subclassed callback instance +custom_callback = CustomCallback(on="train_batch_end", called_every=10) +``` + + +## 3. Adding Callbacks to `TrainConfig` + +To use callbacks in `TrainConfig`, add them to the `callbacks` list when configuring the training process. + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools import TrainConfig +from qadence.ml_tools.callbacks import SaveCheckpoint, PrintMetrics + +config = TrainConfig( + max_iter=10000, + callbacks=[ + SaveCheckpoint(on="val_epoch_end", called_every=50), + PrintMetrics(on="train_epoch_end", called_every=100), + ] +) +``` + +## 4. Using Callbacks with `Trainer` + +The `Trainer` class in `qadence.ml_tools` provides built-in support for executing callbacks at various stages in the training process, managed through a callback manager. By default, several callbacks are added to specific hooks to automate common tasks, such as check-pointing, metric logging, and model tracking. + +### Default Callbacks +Below is a list of the default callbacks and their assigned hooks: + +- **`train_start`**: `PlotMetrics`, `SaveCheckpoint`, `WriteMetrics` +- **`train_epoch_end`**: `SaveCheckpoint`, `PrintMetrics`, `PlotMetrics`, `WriteMetrics` +- **`val_epoch_end`**: `SaveBestCheckpoint`, `WriteMetrics` +- **`train_end`**: `LogHyperparameters`, `LogModelTracker`, `WriteMetrics`, `SaveCheckpoint`, `PlotMetrics` + +These defaults handle common needs, but you can also add custom callbacks to any hook. + +### Example: Adding a Custom Callback + +To create a custom `Trainer` that includes a `PrintMetrics` callback executed specifically at the end of each epoch, follow the steps below. + + +```python exec="on" source="material-block" html="1" +from qadence.ml_tools.trainer import Trainer +from qadence.ml_tools.callbacks import PrintMetrics + +class CustomTrainer(Trainer): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.print_metrics_callback = PrintMetrics(on="train_epoch_end", called_every = 10) + + def on_train_epoch_end(self, train_epoch_loss_metrics): + self.print_metrics_callback.run_callback(self) +``` diff --git a/docs/tutorials/qml/ml_tools/data_and_config.md b/docs/tutorials/qml/ml_tools/data_and_config.md new file mode 100644 index 000000000..1a83b4305 --- /dev/null +++ b/docs/tutorials/qml/ml_tools/data_and_config.md @@ -0,0 +1,232 @@ +## 1. Dataloaders + +When using Qadence, you can supply classical data to a quantum machine learning +algorithm by using a standard PyTorch `DataLoader` instance. Qadence also provides +the `DictDataLoader` convenience class which allows +to build dictionaries of `DataLoader`s instances and easily iterate over them. + +```python exec="on" source="material-block" result="json" +import torch +from torch.utils.data import DataLoader, TensorDataset +from qadence.ml_tools import DictDataLoader, to_dataloader + + +def dataloader(data_size: int = 25, batch_size: int = 5, infinite: bool = False) -> DataLoader: + x = torch.linspace(0, 1, data_size).reshape(-1, 1) + y = torch.sin(x) + return to_dataloader(x, y, batch_size=batch_size, infinite=infinite) + + +def dictdataloader(data_size: int = 25, batch_size: int = 5) -> DictDataLoader: + dls = {} + for k in ["y1", "y2"]: + x = torch.rand(data_size, 1) + y = torch.sin(x) + dls[k] = to_dataloader(x, y, batch_size=batch_size, infinite=True) + return DictDataLoader(dls) + + +# iterate over standard DataLoader +for (x,y) in dataloader(data_size=6, batch_size=2): + print(f"Standard {x = }") + +# construct an infinite dataset which will keep sampling indefinitely +n_epochs = 5 +dl = iter(dataloader(data_size=6, batch_size=2, infinite=True)) +for _ in range(n_epochs): + (x, y) = next(dl) + print(f"Infinite {x = }") + +# iterate over DictDataLoader +ddl = dictdataloader() +data = next(iter(ddl)) +print(f"{data = }") +``` + +Note: + In case of `infinite`=True, the dataloader iterator will provide a random sample from the dataset. + +## 2. Training Configuration + +The [`TrainConfig`][qadence.ml_tools.config.TrainConfig] class provides a comprehensive configuration setup for training quantam machine learning models in Qadence. This configuration includes settings for batch size, logging, check-pointing, validation, and additional custom callbacks that control the training process's granularity and flexibility. + +The [`TrainConfig`][qadence.ml_tools.config.TrainConfig] tells [`Trainer`][qadence.ml_tools.Trainer] what batch_size should be used, how many epochs to train, in which intervals to print/log metrics and how often to store intermediate checkpoints. +It is also possible to provide custom callback functions by instantiating a [`Callback`][qadence.ml_tools.callbacks.Callback] +with a function `callback`. + +For example of how to use the TrainConfig with `Trainer`, please see [Examples in Trainer](/trainer.md) + + +### 2.1 Explanation of `TrainConfig` Attributes + +| Attribute | Type | Default | Description | +|--------------------------|--------------------------|--------------------------|-------------| +| `max_iter` | `int` | `10000` | Total number of training epochs. | +| `batch_size` | `int` | `1` | Batch size for training. | +| `print_every` | `int` | `0` | Frequency of console output. Set to `0` to disable. | +| `write_every` | `int` | `0` | Frequency of logging metrics. Set to `0` to disable. | +| `plot_every` | `int` | `0` | Frequency of plotting metrics. Set to `0` to disable. | +| `checkpoint_every` | `int` | `0` | Frequency of saving checkpoints. Set to `0` to disable. | +| `val_every` | `int` | `0` | Frequency of validation checks. Set to `0` to disable. | +| `val_epsilon` | `float` | `1e-5` | Threshold for validation improvement. | +| `validation_criterion` | `Callable` | `None` | Function for validating metric improvement. | +| `trainstop_criterion` | `Callable` | `None` | Function to stop training early. | +| `callbacks` | `list[Callback]` | `[]` | List of custom callbacks. | +| `root_folder` | `Path` | `"./qml_logs"` | Root directory for saving logs and checkpoints. | +| `log_folder` | `Path` | `"./qml_logs"` | Logging directory for saving logs and checkpoints. | +| `log_model` | `bool` | `False` | Enables model logging. | +| `verbose` | `bool` | `True` | Enables detailed logging. | +| `tracking_tool` | `ExperimentTrackingTool` | `TENSORBOARD` | Tool for tracking training metrics. | +| `plotting_functions` | `tuple` | `()` | Functions for plotting metrics. | + + +```python exec="on" source="material-block" +from qadence.ml_tools import OptimizeResult, TrainConfig +from qadence.ml_tools.callbacks import Callback + +batch_size = 5 +n_epochs = 100 + +print_parameters = lambda opt_res: print(opt_res.model.parameters()) +condition_print = lambda opt_res: opt_res.loss < 1.0e-03 +modify_extra_opt_res = {"n_epochs": n_epochs} +custom_callback = Callback(on="train_end", callback = print_parameters, callback_condition=condition_print, modify_optimize_result=modify_extra_opt_res, called_every=10,) + +config = TrainConfig( + root_folder="some_path/", + max_iter=n_epochs, + checkpoint_every=100, + write_every=100, + batch_size=batch_size, + callbacks = [custom_callback] +) +``` + + +### 2.2 Key Configuration Options in `TrainConfig` + +#### Iterations and Batch Size + +- `max_iter` (**int**): Specifies the total number of training iterations (epochs). For an `InfiniteTensorDataset`, each epoch contains one batch; for a `TensorDataset`, it contains `len(dataloader)` batches. +- `batch_size` (**int**): Defines the number of samples processed in each training iteration. + +Example: +```python +config = TrainConfig(max_iter=2000, batch_size=32) +``` + +#### Training Parameters + +- `print_every` (**int**): Controls how often loss and metrics are printed to the console. +- `write_every` (**int**): Determines how frequently metrics are written to the tracking tool, such as TensorBoard or MLflow. +- `checkpoint_every` (**int**): Sets the frequency for saving model checkpoints. + +Note: Set 0 to diable. + +Example: +```python +config = TrainConfig(print_every=100, write_every=50, checkpoint_every=50) +``` + +The user can provide either the `root_folder` or the `log_folder` for saving checkpoints and logging. When neither are provided, the default `root_folder` "./qml_logs" is used. + +- `root_folder` (**Path**): The root directory for saving checkpoints and logs. All training logs will be saved inside a subfolder in this root directory. (The path to these subfolders can be accessed using config._subfolders, and the current logging folder is config.log_folder) +- `create_subfolder_per_run` (**bool**): Creates a unique subfolder for each training run within the specified folder. +- `tracking_tool` (**ExperimentTrackingTool**): Specifies the tracking tool to log metrics, e.g., TensorBoard or MLflow. +- `log_model` (**bool**): Enables logging of a serialized version of the model, which is useful for model versioning. Thi happens at the end of training. + +Note + - The user can also provide `log_folder` argument - which will only be used when `create_subfolder_per_run` = False. + - `log_folder` (**Path**): The log folder used for saving checkpoints and logs. + +Example: +```python +config = TrainConfig(root_folder="path/to/checkpoints", tracking_tool=ExperimentTrackingTool.MLFLOW, checkpoint_best_only=True) +``` + +#### Validation Parameters + +- `checkpoint_best_only` (**bool**): If set to `True`, saves checkpoints only when there is an improvement in the validation metric. +- `val_every` (**int**): Frequency of validation checks. Setting this to `0` disables validation. +- `val_epsilon` (**float**): A small threshold used to compare the current validation loss with previous best losses. +- `validation_criterion` (**Callable**): A custom function to assess if the validation metric meets a specified condition. + +Example: +```python +config = TrainConfig(val_every=200, checkpoint_best_only = True, validation_criterion=lambda current, best: current < best - 0.001) +``` + +If it is desired to only the save the "best" checkpoint, the following must be ensured: + + (a) `checkpoint_best_only = True` is used while creating the configuration through `TrainConfig`, + (b) `val_every` is set to a valid integer value (for example, `val_every = 10`) which controls the no. of iterations after which the validation data should be used to evaluate the model during training, which can also be set through `TrainConfig`, + (c) a validation criterion is provided through the `validation_criterion`, set through `TrainConfig` to quantify the definition of "best", and + (d) the validation dataloader passed to `Trainer` is of type `DataLoader`. In this case, it is expected that a validation dataloader is also provided along with the train dataloader since the validation data will be used to decide the "best" checkpoint. + +The criterion used to decide the "best" checkpoint can be customized by `validation_criterion`, which should be a function that can take val_loss, best_loss, and val_epsilon arguments and return a boolean value (True or False) indicating whether some validation metric is satisfied or not. An example of a simple `validation_criterion` is: +```python +def validation_criterion(val_loss: float, best_val_loss: float, val_epsilon: float) -> bool: + return val_loss < (best_val_loss - val_epsilon) +``` + +#### Custom Callbacks + +`TrainConfig` supports custom callbacks that can be triggered at specific stages of training. The `callbacks` attribute accepts a list of callback instances, which allow for custom behaviors like early stopping or additional logging. +See [Callbacks](/callbacks.md) for more details. + +- `callbacks` (**list[Callback]**): List of custom callbacks to execute during training. + +Example: +```python +from qadence.ml_tools.callbacks import Callback + +def callback_fn(trainer, config, writer): + if trainer.opt_res.loss < 0.001: + print("Custom Callback: Loss threshold reached!") + +custom_callback = Callback(on = "train_epoch_end", called_every = 10, callback_function = callback_fn ) + +config = TrainConfig(callbacks=[custom_callback]) +``` + +#### Hyperparameters and Plotting + +- `hyperparams` (**dict**): A dictionary of hyperparameters (e.g., learning rate, regularization) to be tracked by the tracking tool. +- `plot_every` (**int**): Determines how frequently plots are saved to the tracking tool, such as TensorBoard or MLflow. +- `plotting_functions` (**tuple[LoggablePlotFunction, ...]**): Functions for in-training plotting of metrics or model state. + +Note: Please ensure that plotting_functions are provided when plot_every > 0 + +Example: +```python +config = TrainConfig( + plot_every=10, + hyperparams={"learning_rate": 0.001, "batch_size": 32}, + plotting_functions=(plot_loss_function,) +) +``` + + + + + + + + +## 3. Experiment tracking with mlflow + +Qadence allows to track runs and log hyperparameters, models and plots with [tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) and [mlflow](https://mlflow.org/). In the following, we demonstrate the integration with mlflow. + +### mlflow configuration +We have control over our tracking configuration by setting environment variables. First, let's look at the tracking URI. For the purpose of this demo we will be working with a local database, in a similar fashion as described [here](https://mlflow.org/docs/latest/tracking/tutorials/local-database.html), +```bash +export MLFLOW_TRACKING_URI=sqlite:///mlruns.db +``` + +Qadence can also read the following two environment variables to define the mlflow experiment name and run name +```bash +export MLFLOW_EXPERIMENT=test_experiment +export MLFLOW_RUN_NAME=run_0 +``` + +If no tracking URI is provided, mlflow stores run information and artifacts in the local `./mlflow` directory and if no names are defined, the experiment and run will be named with random UUIDs. diff --git a/docs/tutorials/qml/ml_tools/trainer.md b/docs/tutorials/qml/ml_tools/trainer.md new file mode 100644 index 000000000..a7b3484b4 --- /dev/null +++ b/docs/tutorials/qml/ml_tools/trainer.md @@ -0,0 +1,676 @@ + +# Qadence Trainer Guide + +The [`Trainer`][qadence.ml_tools.Trainer] class in `qadence.ml_tools` is a versatile tool designed to streamline the training of quantum machine learning models. +It offers flexibility for both gradient-based and gradient-free optimization methods, supports custom loss functions, and integrates seamlessly with tracking tools like TensorBoard and MLflow. +Additionally, it provides hooks for implementing custom behaviors during the training process. + +For training QML models, Qadence offers this out-of-the-box [`Trainer`][qadence.ml_tools.Trainer] for optimizing differentiable +models, _e.g._ `QNN`s and `QuantumModel`, containing either *trainable* and/or *non-trainable* parameters +(see [the parameters tutorial](../../../content/parameters.md) for detailed information about parameter types): + +--- + +## 1. Overview + +The `Trainer` class simplifies the training workflow by managing the training loop, handling data loading, and facilitating model evaluation. +It is compatible with various optimization strategies and allows for extensive customization to meet specific training requirements. + +Example of initializing the `Trainer`: + +```python +from qadence.ml_tools import Trainer, TrainConfig +from torch.optim import Adam + +# Initialize model and optimizer +model = ... # Define or load a quantum model here +optimizer = Adam(model.parameters(), lr=0.01) +config = TrainConfig(max_iter=100, print_every=10) + +# Initialize Trainer with model, optimizer, and configuration +trainer = Trainer(model=model, optimizer=optimizer, config=config) +``` + +> Notes: +> `qadence` versions prior to 1.9.0 provided `train_with_grad` and `train_no_grad` functions, which are being replaced with `Trainer`. The user can transition as following. +> ```python +> from qadence.ml_tools import train_with_grad +> train_with_grad(model=model, optimizer=optimizer, config=config, data = data) +> ``` +> to +> ```python +> from qadence.ml_tools import Trainer +> trainer = Trainer(model=model, optimizer=optimizer, config=config) +> trainer.fit(train_dataloader = data) +> ``` + +## 2. Gradient-Based and Gradient-Free Optimization + +The `Trainer` supports both gradient-based and gradient-free optimization methods. +Default is gradient-based optimization. + +- **Gradient-Based Optimization**: Utilizes optimizers from PyTorch's `torch.optim` module. +This is the default behaviour of the `Trainer`, thus setting this is not necessary. +However, it can be explicity mentioned as follows. +Example of using gradient-based optimization: + +```python exec="on" source="material-block" +from qadence.ml_tools import Trainer + +# set_use_grad(True) to enable gradient based training. This is the default behaviour of Trainer. +Trainer.set_use_grad(True) +``` + +- **Gradient-Free Optimization**: Employs optimization algorithms from the [Nevergrad](https://facebookresearch.github.io/nevergrad/) library. + + +Example of using gradient-free optimization with Nevergrad: + +```python exec="on" source="material-block" +from qadence.ml_tools import Trainer + +# set_use_grad(False) to disable gradient based training. +Trainer.set_use_grad(False) +``` + +### Using Context Managers for Mixed Optimization + +For cases requiring both optimization methods in a single training session, the `Trainer` class provides context managers to enable or disable gradients. + +```python +# Temporarily switch to gradient-based optimization +with trainer.enable_grad_opt(optimizer): + print("Gradient Based Optimization") + # trainer.fit(train_loader) + +# Switch to gradient-free optimization for specific steps +with trainer.disable_grad_opt(ng_optimizer): + print("Gradient Free Optimization") + # trainer.fit(train_loader) +``` + +--- + +## 3. Custom Loss Functions + +Users can define custom loss functions tailored to their specific tasks. +The `Trainer` accepts a `loss_fn` parameter, which should be a callable that takes the model and data as inputs and returns a tuple containing the loss tensor and a dictionary of metrics. + +Example of using a custom loss function: + +```python exec="on" source="material-block" +import torch +from itertools import count +cnt = count() +criterion = torch.nn.MSELoss() + +def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: + next(cnt) + x, y = data + out = model(x) + loss = criterion(out, y) + return loss, {} +``` + +This custom loss function can be used in the trainer +```python +from qadence.ml_tools import Trainer, TrainConfig +from torch.optim import Adam + +# Initialize model and optimizer +model = ... # Define or load a quantum model here +optimizer = Adam(model.parameters(), lr=0.01) +config = TrainConfig(max_iter=100, print_every=10) + +trainer = Trainer(model=model, optimizer=optimizer, config=config, loss_fn=loss_fn) +``` + + +--- + +## 4. Hooks for Custom Behavior + +The `Trainer` class provides several hooks that enable users to inject custom behavior at different stages of the training process. +These hooks are methods that can be overridden in a subclass to execute custom code. +The available hooks include: + +- `on_train_start`: Called at the beginning of the training process. +- `on_train_end`: Called at the end of the training process. +- `on_train_epoch_start`: Called at the start of each training epoch. +- `on_train_epoch_end`: Called at the end of each training epoch. +- `on_train_batch_start`: Called at the start of each training batch. +- `on_train_batch_end`: Called at the end of each training batch. + +Each "start" and "end" hook receives data and loss metrics as arguments. The specific values provided for these arguments depend on the training stage associated with the hook. The context of the training stage (e.g., training, validation, or testing) determines which metrics are relevant and how they are populated. For details of inputs on each hook, please review the documentation of [`BaseTrainer`][qadence.ml_tools.train_utils.BaseTrainer]. + + - Example of what inputs are provided to training hooks. + + ``` + def on_train_batch_start(self, batch: Tuple[torch.Tensor, ...] | None) -> None: + """ + Called at the start of each training batch. + + Args: + batch: A batch of data from the DataLoader. Typically a tuple containing + input tensors and corresponding target tensors. + """ + pass + ``` + ``` + def on_train_batch_end(self, train_batch_loss_metrics: Tuple[torch.Tensor, Any]) -> None: + """ + Called at the end of each training batch. + + Args: + train_batch_loss_metrics: Metrics for the training batch loss. + Tuple of (loss, metrics) + """ + pass + ``` + +Example of using a hook to log a message at the end of each epoch: + +```python exec="on" source="material-block" +from qadence.ml_tools import Trainer + +class CustomTrainer(Trainer): + def on_train_epoch_end(self, train_epoch_loss_metrics): + print(f"End of epoch - Loss and Metrics: {train_epoch_loss_metrics}") +``` + +> Notes: +> Trainer offers inbuilt callbacks as well. Callbacks are mainly for logging/tracking purposes, but the above mentioned hooks are generic. The workflow for every train batch looks like: +> 1. perform on_train_batch_start callbacks, +> 2. call the on_train_batch_start hook, +> 3. do the batch training, +> 4. call the on_train_batch_end hook, and +> 5. perform on_train_batch_end callbacks. +> +> The use of `on_`*{phase}*`_start` and `on_`*{phase}*`_end` hooks is not specifically to add extra callbacks, but for any other generic pre/post processing. For example, reshaping input batch in case of RNNs/LSTMs, post processing loss and adding an extra metric. They could also be used to add more callbacks (which is not recommended - as we provide methods to add extra callbacks in the TrainCofig) + +--- + +## 5. Experiment Tracking with TensorBoard and MLflow + +The `Trainer` integrates with TensorBoard and MLflow for experiment tracking: + +- **TensorBoard**: Logs metrics and visualizations during training, allowing users to monitor the training process. + +- **MLflow**: Tracks experiments, logs parameters, metrics, and artifacts, and provides a user-friendly interface for comparing different runs. + +To utilize these tracking tools, the `Trainer` can be configured with appropriate writers that handle the logging of metrics and other relevant information during training. + +Example of using TensorBoard tracking: + +```python +from qadence.ml_tools import TrainConfig +from qadence.types import ExperimentTrackingTool + +# Set up tracking with TensorBoard +config = TrainConfig(max_iter=100, tracking_tool=ExperimentTrackingTool.TENSORBOARD) +``` + +Example of using MLflow tracking: + +```python +from qadence.types import ExperimentTrackingTool + +# Set up tracking with MLflow +config = TrainConfig(max_iter=100, tracking_tool=ExperimentTrackingTool.MLFLOW) +``` + +## 6. Examples + +### 6.1. Training with `Trainer` and `TrainConfig` + +#### Setup +Let's do the necessary imports and declare a `DataLoader`. We can already define some hyperparameters here, including the seed for random number generators. mlflow can log hyperparameters with arbitrary types, for example the observable that we want to monitor (`Z` in this case, which has a `qadence.Operation` type). + +```python +import random +from itertools import count + +import numpy as np +import torch +from matplotlib import pyplot as plt +from matplotlib.figure import Figure +from torch.nn import Module +from torch.utils.data import DataLoader + +from qadence import hea, QuantumCircuit, Z +from qadence.constructors import feature_map, hamiltonian_factory +from qadence.ml_tools import Trainer, TrainConfig +from qadence.ml_tools.data import to_dataloader +from qadence.ml_tools.utils import rand_featureparameters +from qadence.ml_tools.models import QNN, QuantumModel +from qadence.types import ExperimentTrackingTool + +hyperparams = { + "seed": 42, + "batch_size": 10, + "n_qubits": 2, + "ansatz_depth": 1, + "observable": Z, +} + +np.random.seed(hyperparams["seed"]) +torch.manual_seed(hyperparams["seed"]) +random.seed(hyperparams["seed"]) + + +def dataloader(batch_size: int = 25) -> DataLoader: + x = torch.linspace(0, 1, batch_size).reshape(-1, 1) + y = torch.cos(x) + return to_dataloader(x, y, batch_size=batch_size, infinite=True) +``` + +We continue with the regular QNN definition, together with the loss function and optimizer. + +```python +obs = hamiltonian_factory(register=hyperparams["n_qubits"], detuning=hyperparams["observable"]) + +data = dataloader(hyperparams["batch_size"]) +fm = feature_map(hyperparams["n_qubits"], param="x") + +model = QNN( + QuantumCircuit( + hyperparams["n_qubits"], fm, hea(hyperparams["n_qubits"], hyperparams["ansatz_depth"]) + ), + observable=obs, + inputs=["x"], +) + +cnt = count() +criterion = torch.nn.MSELoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + +inputs = rand_featureparameters(model, 1) + +def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict]: + next(cnt) + out = model.expectation(inputs) + loss = criterion(out, torch.rand(1)) + return loss, {} +``` + +#### `TrainConfig` specifications +Qadence offers different tracking options via `TrainConfig`. Here we use the `ExperimentTrackingTool` type to specify that we want to track the experiment with mlflow. Tracking with tensorboard is also possible. We can then indicate *what* and *how often* we want to track or log. + +**For Training** +`write_every` controls the number of epochs after which the loss values is logged. Thanks to the `plotting_functions` and `plot_every`arguments, we are also able to plot model-related quantities throughout training. Notice that arbitrary plotting functions can be passed, as long as the signature is the same as `plot_fn` below. Finally, the trained model can be logged by setting `log_model=True`. Here is an example of plotting function and training configuration + +```python +def plot_fn(model: Module, iteration: int) -> tuple[str, Figure]: + descr = f"ufa_prediction_epoch_{iteration}.png" + fig, ax = plt.subplots() + x = torch.linspace(0, 1, 100).reshape(-1, 1) + out = model.expectation(x) + ax.plot(x.detach().numpy(), out.detach().numpy()) + return descr, fig + + +config = TrainConfig( + root_folder="mlflow_demonstration", + max_iter=10, + checkpoint_every=1, + plot_every=2, + write_every=1, + log_model=True, + tracking_tool=ExperimentTrackingTool.MLFLOW, + hyperparams=hyperparams, + plotting_functions=(plot_fn,), +) +``` + +#### Training and inspecting +Model training happens as usual +```python +trainer = Trainer(model, optimizer, config, loss_fn) +trainer.fit(train_dataloader=data) +``` + +After training , we can inspect our experiment via the mlflow UI +```bash +mlflow ui --port 8080 --backend-store-uri sqlite:///mlruns.db +``` +In this case, since we're running on a local server, we can access the mlflow UI by navigating to http://localhost:8080/. + + +### 6.2. Fitting a function with a QNN using `ml_tools` + +In Quantum Machine Learning, the general consensus is to use `complex128` precision for states and operators and `float64` precision for parameters. This is also the convention which is used in `qadence`. +However, for specific usecases, lower precision can greatly speed up training and reduce memory consumption. When using the `pyqtorch` backend, `qadence` offers the option to move a `QuantumModel` instance to a specific precision using the torch `to` syntax. + +Let's look at a complete example of how to use `Trainer` now. Here we perform a validation check during training and use a validation criterion that checks whether the validation loss in the current iteration has decreased compared to the lowest validation loss from all previous iterations. For demonstration, the train and the validation data are kept the same here. However, it is beneficial and encouraged to keep them distinct in practice to understand model's generalization capabilities. + +```python exec="on" source="material-block" html="1" +from pathlib import Path +import torch +from functools import reduce +from operator import add +from itertools import count +import matplotlib.pyplot as plt + +from qadence import Parameter, QuantumCircuit, Z +from qadence import hamiltonian_factory, hea, feature_map, chain +from qadence import QNN +from qadence.ml_tools import TrainConfig, Trainer, to_dataloader + +Trainer.set_use_grad(True) + +DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') +DTYPE = torch.complex64 +n_qubits = 4 +fm = feature_map(n_qubits) +ansatz = hea(n_qubits=n_qubits, depth=3) +observable = hamiltonian_factory(n_qubits, detuning=Z) +circuit = QuantumCircuit(n_qubits, fm, ansatz) + +model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad") +batch_size = 100 +input_values = {"phi": torch.rand(batch_size, requires_grad=True)} +pred = model(input_values) + +cnt = count() +criterion = torch.nn.MSELoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + +def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: + next(cnt) + x, y = data[0], data[1] + out = model(x) + loss = criterion(out, y) + return loss, {} + +def validation_criterion( + current_validation_loss: float, current_best_validation_loss: float, val_epsilon: float +) -> bool: + return current_validation_loss <= current_best_validation_loss - val_epsilon + +n_epochs = 300 + +config = TrainConfig( + max_iter=n_epochs, + batch_size=batch_size, + checkpoint_best_only=True, + val_every=10, # The model will be run on the validation data after every `val_every` epochs. + validation_criterion=validation_criterion +) + +fn = lambda x, degree: .05 * reduce(add, (torch.cos(i*x) + torch.sin(i*x) for i in range(degree)), 0.) +x = torch.linspace(0, 10, batch_size, dtype=torch.float32).reshape(-1, 1) +y = fn(x, 5) + +train_dataloader = to_dataloader(x, y, batch_size=batch_size, infinite=True) +val_dataloader = to_dataloader(x, y, batch_size=batch_size, infinite=True) + +trainer = Trainer(model, optimizer, config, loss_fn=loss_fn, + train_dataloader = train_dataloader, val_dataloader = val_dataloader, + device=DEVICE, dtype=DTYPE) +trainer.fit() + +plt.clf() +plt.plot(x.numpy(), y.numpy(), label='truth') +plt.plot(x.numpy(), model(x).detach().numpy(), "--", label="final", linewidth=3) +plt.legend() +from docs import docsutils # markdown-exec: hide +print(docsutils.fig_to_html(plt.gcf())) # markdown-exec: hide +``` + + +### 6.3. Fitting a function - Low-level API + +For users who want to use the low-level API of `qadence`, here an example written without `Trainer`. + +```python exec="on" source="material-block" +from pathlib import Path +import torch +from itertools import count +from qadence.constructors import hamiltonian_factory, hea, feature_map +from qadence import chain, Parameter, QuantumCircuit, Z +from qadence import QNN +from qadence.ml_tools import TrainConfig + +n_qubits = 2 +fm = feature_map(n_qubits) +ansatz = hea(n_qubits=n_qubits, depth=3) +observable = hamiltonian_factory(n_qubits, detuning=Z) +circuit = QuantumCircuit(n_qubits, fm, ansatz) + +model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad") +batch_size = 1 +input_values = {"phi": torch.rand(batch_size, requires_grad=True)} +pred = model(input_values) + +criterion = torch.nn.MSELoss() +optimizer = torch.optim.Adam(model.parameters(), lr=0.1) +n_epochs=50 +cnt = count() + +tmp_path = Path("/tmp") + +config = TrainConfig( + root_folder=tmp_path, + max_iter=n_epochs, + checkpoint_every=100, + write_every=100, + batch_size=batch_size, +) + +x = torch.linspace(0, 1, batch_size).reshape(-1, 1) +y = torch.sin(x) + +for i in range(n_epochs): + out = model(x) + loss = criterion(out, y) + loss.backward() + optimizer.step() +``` + + + +### 6.4. Custom `train` loop + +If you need custom training functionality that goes beyond what is available in +`qadence.ml_tools.Trainer` you can write your own +training loop based on the building blocks that are available in Qadence. + +A simplified version of Qadence's train loop is defined below. Feel free to copy it and modify at +will. + +For logging we can use the `get_writer` from the `Writer Registry`. This will set up the default writer based on the experiment tracking tool. +All writers from the `Writer Registry` offer `open`, `close`, `print_metrics`, `write_metrics`, `plot_metrics`, etc methods. + + +```python +from typing import Callable, Union + +from torch.nn import Module +from torch.optim import Optimizer +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from qadence.ml_tools.config import TrainConfig +from qadence.ml_tools.data import DictDataLoader, data_to_device +from qadence.ml_tools.optimize_step import optimize_step +from qadence.ml_tools.callbacks import get_writer +from qadence.ml_tools.callbacks.saveload import load_checkpoint, write_checkpoint + + +def train( + model: Module, + data: DataLoader, + optimizer: Optimizer, + config: TrainConfig, + loss_fn: Callable, + device: str = "cpu", + optimize_step: Callable = optimize_step, + write_tensorboard: Callable = write_tensorboard, +) -> tuple[Module, Optimizer]: + + # Move model to device before optimizer is loaded + model = model.to(device) + + # load available checkpoint + init_iter = 0 + if config.log_folder: + model, optimizer, init_iter = load_checkpoint(config.log_folder, model, optimizer) + + # Initialize writer based on the tracking tool specified in the configuration + writer = get_writer(config.tracking_tool) # Uses ExperimentTrackingTool to select writer + writer.open(config, iteration=init_iter) + + dl_iter = iter(dataloader) + + # outer epoch loop + for iteration in range(init_iter, init_iter + config.max_iter): + data = data_to_device(next(dl_iter), device) + loss, metrics = optimize_step(model, optimizer, loss_fn, data) + + if iteration % config.print_every == 0 and config.verbose: + writer.print_metrics(OptimizeResult(iteration, model, optimizer, loss, metrics)) + + if iteration % config.write_every == 0: + writer.write(OptimizeResult(iteration, model, optimizer, loss, metrics)) + + if config.log_folder: + if iteration % config.checkpoint_every == 0: + write_checkpoint(config.log_folder, model, optimizer, iteration) + + # Final writing and checkpointing + if config.log_folder: + write_checkpoint(config.log_folder, model, optimizer, iteration) + writer.write(OptimizeResult(iteration, model, optimizer, loss, metrics)) + writer.close() + + return model, optimizer +``` + +### 6.5. Gradient-free optimization using `Trainer` + +Solving a QUBO using gradient free optimization based on `Nevergrad` optimizers and `Trainer`. This problem is further defined in [QUBO Tutorial](../../digital_analog_qc/analog-qubo.md) + + +We can achieve gradient free optimization by. +``` +Trainer.set_use_grad(False) + +# or + +trainer.disable_grad_opt(ng_optimizer): + print("Gradient free opt") +``` + + +```python exec="on" source="material-block" session="qubo" +import numpy as np +import numpy.typing as npt +from scipy.optimize import minimize +from scipy.spatial.distance import pdist, squareform +from qadence import RydbergDevice + +import torch +from qadence import QuantumModel, QuantumCircuit, Register +from qadence import RydbergDevice, AnalogRX, AnalogRZ, chain +from qadence.ml_tools import Trainer, TrainConfig, num_parameters +import nevergrad as ng +import matplotlib.pyplot as plt + +Trainer.set_use_grad(False) + +seed = 0 +np.random.seed(seed) +torch.manual_seed(seed) + +def qubo_register_coords(Q: np.ndarray, device: RydbergDevice) -> list: + """Compute coordinates for register.""" + + def evaluate_mapping(new_coords, *args): + """Cost function to minimize. Ideally, the pairwise + distances are conserved""" + Q, shape = args + new_coords = np.reshape(new_coords, shape) + interaction_coeff = device.rydberg_level + new_Q = squareform(interaction_coeff / pdist(new_coords) ** 6) + return np.linalg.norm(new_Q - Q) + + shape = (len(Q), 2) + np.random.seed(0) + x0 = np.random.random(shape).flatten() + res = minimize( + evaluate_mapping, + x0, + args=(Q, shape), + method="Nelder-Mead", + tol=1e-6, + options={"maxiter": 200000, "maxfev": None}, + ) + return [(x, y) for (x, y) in np.reshape(res.x, (len(Q), 2))] + + +# QUBO problem weights (real-value symmetric matrix) +Q = np.array([ + [-10.0, 19.7365809, 19.7365809, 5.42015853, 5.42015853], + [19.7365809, -10.0, 20.67626392, 0.17675796, 0.85604541], + [19.7365809, 20.67626392, -10.0, 0.85604541, 0.17675796], + [5.42015853, 0.17675796, 0.85604541, -10.0, 0.32306662], + [5.42015853, 0.85604541, 0.17675796, 0.32306662, -10.0], + ]) + +# Device specification and atomic register +device = RydbergDevice(rydberg_level=70) + +reg = Register.from_coordinates( +qubo_register_coords(Q, device), device_specs=device) + +# Analog variational quantum circuit +layers = 2 +block = chain(*[AnalogRX(f"t{i}") * AnalogRZ(f"s{i}") for i in range(layers)]) +circuit = QuantumCircuit(reg, block) + +model = QuantumModel(circuit) +initial_counts = model.sample({}, n_shots=1000)[0] + +print(f"initial_counts = {initial_counts}") # markdown-exec: hide + +def loss(model: QuantumModel, *args) -> tuple[torch.Tensor, dict]: + to_arr_fn = lambda bitstring: np.array(list(bitstring), dtype=int) + cost_fn = lambda arr: arr.T @ Q @ arr + samples = model.sample({}, n_shots=1000)[0] # extract samples + cost_fn = sum(samples[key] * cost_fn(to_arr_fn(key)) for key in samples) + return torch.tensor(cost_fn / sum(samples.values())), {} # We return an optional metrics dict + + + +# Training +config = TrainConfig(max_iter=100) +optimizer = ng.optimizers.NGOpt( + budget=config.max_iter, parametrization=num_parameters(model) + ) +trainer = Trainer(model, optimizer, config, loss) +trainer.fit() + +optimal_counts = model.sample({}, n_shots=1000)[0] +print(f"optimal_count = {optimal_counts}") # markdown-exec: hide + + +# Known solutions to the QUBO problem. +solution_bitstrings = ["01011", "00111"] + +def plot_distribution(C, ax, title): + C = dict(sorted(C.items(), key=lambda item: item[1], reverse=True)) + indexes = solution_bitstrings # QUBO solutions + color_dict = {key: "r" if key in indexes else "g" for key in C} + ax.set_xlabel("bitstrings") + ax.set_ylabel("counts") + ax.set_xticks([i for i in range(len(C.keys()))], C.keys(), rotation=90) + ax.bar(list(C.keys())[:20], list(C.values())[:20]) + ax.set_title(title) + +plt.tight_layout() # markdown-exec: hide +fig, axs = plt.subplots(1, 2, figsize=(12, 4)) +plot_distribution(initial_counts, axs[0], "Initial counts") +plot_distribution(optimal_counts, axs[1], "Optimal counts") +from docs import docsutils # markdown-exec: hide +print(docsutils.fig_to_html(fig)) # markdown-exec: hide +``` diff --git a/docs/tutorials/qml/qaoa.md b/docs/tutorials/qml/qaoa.md index 31570c091..96c34431b 100644 --- a/docs/tutorials/qml/qaoa.md +++ b/docs/tutorials/qml/qaoa.md @@ -195,7 +195,7 @@ for i in range(n_epochs): ``` Qadence offers some convenience functions to implement this training loop with advanced -logging and metrics track features. You can refer to [this tutorial](ml_tools.md) for more details. +logging and metrics track features. You can refer to [this tutorial](ml_tools/trainer.md) for more details. ## Results diff --git a/docs/tutorials/qml/qcl.md b/docs/tutorials/qml/qcl.md index 4959bfb02..6b4eb46b9 100644 --- a/docs/tutorials/qml/qcl.md +++ b/docs/tutorials/qml/qcl.md @@ -114,7 +114,7 @@ assert loss.item() < 1e-3 ``` Qadence offers some convenience functions to implement this training loop with advanced -logging and metrics track features. You can refer to [this tutorial](ml_tools.md) for more details. +logging and metrics track features. You can refer to [this tutorial](ml_tools/trainer.md) for more details. The quantum model is now trained on the training data points. To determine the quality of the results, one can check to see how well it fits the function on the test set. diff --git a/examples/quick_start.py b/examples/quick_start.py index 09b9bee17..206db73b2 100644 --- a/examples/quick_start.py +++ b/examples/quick_start.py @@ -8,7 +8,7 @@ from qadence.blocks import kron # block system from qadence.circuit import QuantumCircuit # circuit to assemble quantum operations from qadence.logger import get_script_logger # Extend Qadence logging to your scripts -from qadence.ml_tools import TrainConfig, train_with_grad # tools for ML simulations +from qadence.ml_tools import TrainConfig, Trainer # tools for ML simulations from qadence.operations import RX, HamEvo, X, Y, Zero # quantum operations from qadence.parameters import VariationalParameter # trainable parameters @@ -57,4 +57,5 @@ def loss_fn(model_: QuantumModel, _): optimizer = torch.optim.Adam(model.parameters(), lr=0.1) config = TrainConfig(max_iter=100, checkpoint_every=10, print_every=10) -train_with_grad(model, None, optimizer, config, loss_fn=loss_fn) +trainer = Trainer(model, optimizer, config, loss_fn) +model, optimizer = trainer.fit() diff --git a/mkdocs.yml b/mkdocs.yml index 378e784cb..60cd51f95 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -42,7 +42,9 @@ nav: - Variational quantum algorithms: - tutorials/qml/index.md - - Training tools: tutorials/qml/ml_tools.md + - Training: tutorials/qml/ml_tools/trainer.md + - Training Callbacks: tutorials/qml/ml_tools/callbacks.md + - Data and Configurations: tutorials/qml/ml_tools/data_and_config.md - Configuring a QNN: tutorials/qml/config_qnn.md - Quantum circuit learning: tutorials/qml/qcl.md - Solving MaxCut with QAOA: tutorials/qml/qaoa.md diff --git a/qadence/__init__.py b/qadence/__init__.py index b97ed79e8..2d83651c3 100644 --- a/qadence/__init__.py +++ b/qadence/__init__.py @@ -34,7 +34,7 @@ [ h.setLevel(LOG_LEVEL) # type: ignore[func-returns-value] for h in logger.handlers - if h.get_name() == "console" + if h.get_name() == "console" or h.get_name() == "richconsole" ] logger.debug(f"Qadence logger successfully setup with log level {LOG_LEVEL}") diff --git a/qadence/log_config.yaml b/qadence/log_config.yaml index be57eeb60..f441f5682 100644 --- a/qadence/log_config.yaml +++ b/qadence/log_config.yaml @@ -4,11 +4,17 @@ formatters: base: format: "%(levelname) -5s %(asctime)s - %(name)s: %(message)s" datefmt: "%Y-%m-%d %H:%M:%S" + empty: + format: "%(message)s" # Rich formatter for cleaner output + datefmt: "%Y-%m-%d %H:%M:%S" handlers: console: class: logging.StreamHandler formatter: base stream: ext://sys.stderr + richconsole: + class: rich.logging.RichHandler + formatter: empty loggers: qadence: level: INFO @@ -22,3 +28,7 @@ loggers: level: INFO handlers: [console] propagate: yes + ml_tools: + level: INFO + handlers: [richconsole] + propagate: false diff --git a/qadence/ml_tools/__init__.py b/qadence/ml_tools/__init__.py index 9a3a980f2..4a4334c32 100644 --- a/qadence/ml_tools/__init__.py +++ b/qadence/ml_tools/__init__.py @@ -1,16 +1,14 @@ from __future__ import annotations -from .config import AnsatzConfig, Callback, FeatureMapConfig, TrainConfig +from .callbacks.saveload import load_checkpoint, load_model, write_checkpoint +from .config import AnsatzConfig, FeatureMapConfig, TrainConfig from .constructors import create_ansatz, create_fm_blocks, observable_from_config from .data import DictDataLoader, InfiniteTensorDataset, OptimizeResult, to_dataloader from .models import QNN from .optimize_step import optimize_step as default_optimize_step from .parameters import get_parameters, num_parameters, set_parameters -from .printing import print_metrics, write_tensorboard -from .saveload import load_checkpoint, load_model, write_checkpoint from .tensors import numpy_to_tensor, promote_to, promote_to_tensor -from .train_grad import train as train_with_grad -from .train_no_grad import train as train_gradient_free +from .trainer import Trainer # Modules to be automatically added to the qadence namespace __all__ = [ @@ -24,8 +22,6 @@ "QNN", "TrainConfig", "OptimizeResult", - "Callback", - "train_with_grad", - "train_gradient_free", + "Trainer", "write_checkpoint", ] diff --git a/qadence/ml_tools/callbacks/__init__.py b/qadence/ml_tools/callbacks/__init__.py new file mode 100644 index 000000000..2eca22712 --- /dev/null +++ b/qadence/ml_tools/callbacks/__init__.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from .callback import ( + Callback, + LoadCheckpoint, + LogHyperparameters, + LogModelTracker, + PlotMetrics, + PrintMetrics, + SaveBestCheckpoint, + SaveCheckpoint, + WriteMetrics, +) +from .callbackmanager import CallbacksManager +from .writer_registry import get_writer + +# Modules to be automatically added to the qadence.ml_tools.callbacks namespace +__all__ = [ + "CallbacksManager", + "Callback", + "LoadCheckpoint", + "LogHyperparameters", + "LogModelTracker", + "PlotMetrics", + "PrintMetrics", + "SaveBestCheckpoint", + "SaveCheckpoint", + "WriteMetrics", + "get_writer", +] diff --git a/qadence/ml_tools/callbacks/callback.py b/qadence/ml_tools/callbacks/callback.py new file mode 100644 index 000000000..bda561304 --- /dev/null +++ b/qadence/ml_tools/callbacks/callback.py @@ -0,0 +1,451 @@ +from __future__ import annotations + +from typing import Any, Callable + +from qadence.ml_tools.callbacks.saveload import load_checkpoint, write_checkpoint +from qadence.ml_tools.callbacks.writer_registry import BaseWriter +from qadence.ml_tools.config import TrainConfig +from qadence.ml_tools.data import OptimizeResult +from qadence.ml_tools.stages import TrainingStage + +# Define callback types +CallbackFunction = Callable[..., Any] +CallbackConditionFunction = Callable[..., bool] + + +class Callback: + """Base class for defining various training callbacks. + + Attributes: + on (str): The event on which to trigger the callback. + Must be a valid on value from: ["train_start", "train_end", + "train_epoch_start", "train_epoch_end", "train_batch_start", + "train_batch_end","val_epoch_start", "val_epoch_end", + "val_batch_start", "val_batch_end", "test_batch_start", + "test_batch_end"] + called_every (int): Frequency of callback calls in terms of iterations. + callback (CallbackFunction | None): The function to call if the condition is met. + callback_condition (CallbackConditionFunction | None): Condition to check before calling. + modify_optimize_result (CallbackFunction | dict[str, Any] | None): + Function to modify `OptimizeResult`. + + A callback can be defined in two ways: + + 1. **By providing a callback function directly in the base class**: + This is useful for simple callbacks that don't require subclassing. + + Example: + ```python exec="on" source="material-block" result="json" + from qadence.ml_tools.callbacks import Callback + + def custom_callback_function(trainer, config, writer): + print("Custom callback executed.") + + custom_callback = Callback( + on="train_end", + called_every=5, + callback=custom_callback_function + ) + ``` + + 2. **By inheriting and implementing the `run_callback` method**: + This is suitable for more complex callbacks that require customization. + + Example: + ```python exec="on" source="material-block" result="json" + from qadence.ml_tools.callbacks import Callback + class CustomCallback(Callback): + def run_callback(self, trainer, config, writer): + print("Custom behavior in the inherited run_callback method.") + + custom_callback = CustomCallback(on="train_end", called_every=10) + ``` + """ + + VALID_ON_VALUES = [ + "train_start", + "train_end", + "train_epoch_start", + "train_epoch_end", + "train_batch_start", + "train_batch_end", + "val_epoch_start", + "val_epoch_end", + "val_batch_start", + "val_batch_end", + "test_batch_start", + "test_batch_end", + ] + + def __init__( + self, + on: str | TrainingStage = "idle", + called_every: int = 1, + callback: CallbackFunction | None = None, + callback_condition: CallbackConditionFunction | None = None, + modify_optimize_result: CallbackFunction | dict[str, Any] | None = None, + ): + if not isinstance(called_every, int): + raise ValueError("called_every must be a positive integer or 0") + + self.callback: CallbackFunction | None = callback + self.on: str | TrainingStage = on + self.called_every: int = called_every + self.callback_condition = callback_condition or (lambda _: True) + + if isinstance(modify_optimize_result, dict): + self.modify_optimize_result = ( + lambda opt_res: opt_res.extra.update(modify_optimize_result) or opt_res + ) + else: + self.modify_optimize_result = modify_optimize_result or (lambda opt_res: opt_res) + + @property + def on(self) -> TrainingStage | str: + """ + Returns the TrainingStage. + + Returns: + TrainingStage: TrainingStage for the callback + """ + return self._on + + @on.setter + def on(self, on: str | TrainingStage) -> None: + """ + Sets the training stage on for the callback. + + Args: + on (str | TrainingStage): TrainingStage for the callback + """ + if isinstance(on, str): + if on not in self.VALID_ON_VALUES: + raise ValueError(f"Invalid value for 'on'. Must be one of {self.VALID_ON_VALUES}.") + self._on = TrainingStage(on) + elif isinstance(on, TrainingStage): + self._on = on + else: + raise ValueError("Invalid value for 'on'. Must be `str` or `TrainingStage`.") + + def _should_call(self, when: str, opt_result: OptimizeResult) -> bool: + """Checks if the callback should be called. + + Args: + when (str): The event when the callback is considered for execution. + opt_result (OptimizeResult): The current optimization results. + + Returns: + bool: Whether the callback should be called. + """ + if when in [TrainingStage("train_start"), TrainingStage("train_end")]: + return True + if self.called_every == 0 or opt_result.iteration == 0: + return False + if opt_result.iteration % self.called_every == 0 and self.callback_condition(opt_result): + return True + return False + + def __call__( + self, when: TrainingStage, trainer: Any, config: TrainConfig, writer: BaseWriter + ) -> Any: + """Executes the callback if conditions are met. + + Args: + when (str): The event when the callback is triggered. + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + + Returns: + Any: Result of the callback function if executed. + """ + opt_result = trainer.opt_result + if self.on == when: + if opt_result: + opt_result = self.modify_optimize_result(opt_result) + if self._should_call(when, opt_result): + return self.run_callback(trainer, config, writer) + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Executes the defined callback. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + + Returns: + Any: Result of the callback execution. + + Raises: + NotImplementedError: If not implemented in subclasses. + """ + if self.callback is not None: + return self.callback(trainer, config, writer) + raise NotImplementedError("Subclasses should override the run_callback method.") + + +class PrintMetrics(Callback): + """Callback to print metrics using the writer. + + The `PrintMetrics` callback can be added to the `TrainConfig` + callbacks as a custom user defined callback. + + Example Usage in `TrainConfig`: + To use `PrintMetrics`, include it in the `callbacks` list when + setting up your `TrainConfig`: + ```python exec="on" source="material-block" result="json" + from qadence.ml_tools import TrainConfig + from qadence.ml_tools.callbacks import PrintMetrics + + # Create an instance of the PrintMetrics callback + print_metrics_callback = PrintMetrics(on = "val_batch_end", called_every = 100) + + config = TrainConfig( + max_iter=10000, + # Print metrics every 1000 training epochs + print_every=1000, + # Add the custom callback that runs every 100 val_batch_end + callbacks=[print_metrics_callback] + ) + ``` + """ + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Prints metrics using the writer. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + """ + opt_result = trainer.opt_result + writer.print_metrics(opt_result) + + +class WriteMetrics(Callback): + """Callback to write metrics using the writer. + + The `WriteMetrics` callback can be added to the `TrainConfig` callbacks as + a custom user defined callback. + + Example Usage in `TrainConfig`: + To use `WriteMetrics`, include it in the `callbacks` list when setting up your + `TrainConfig`: + ```python exec="on" source="material-block" result="json" + from qadence.ml_tools import TrainConfig + from qadence.ml_tools.callbacks import WriteMetrics + + # Create an instance of the WriteMetrics callback + write_metrics_callback = WriteMetrics(on = "val_batch_end", called_every = 100) + + config = TrainConfig( + max_iter=10000, + # Print metrics every 1000 training epochs + print_every=1000, + # Add the custom callback that runs every 100 val_batch_end + callbacks=[write_metrics_callback] + ) + ``` + """ + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Writes metrics using the writer. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + """ + opt_result = trainer.opt_result + writer.write(opt_result) + + +class PlotMetrics(Callback): + """Callback to plot metrics using the writer. + + The `PlotMetrics` callback can be added to the `TrainConfig` callbacks as + a custom user defined callback. + + Example Usage in `TrainConfig`: + To use `PlotMetrics`, include it in the `callbacks` list when setting up your + `TrainConfig`: + ```python exec="on" source="material-block" result="json" + from qadence.ml_tools import TrainConfig + from qadence.ml_tools.callbacks import PlotMetrics + + # Create an instance of the PlotMetrics callback + plot_metrics_callback = PlotMetrics(on = "val_batch_end", called_every = 100) + + config = TrainConfig( + max_iter=10000, + # Print metrics every 1000 training epochs + print_every=1000, + # Add the custom callback that runs every 100 val_batch_end + callbacks=[plot_metrics_callback] + ) + ``` + """ + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Plots metrics using the writer. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + """ + opt_result = trainer.opt_result + plotting_functions = config.plotting_functions + writer.plot(trainer.model, opt_result.iteration, plotting_functions) + + +class LogHyperparameters(Callback): + """Callback to log hyperparameters using the writer. + + The `LogHyperparameters` callback can be added to the `TrainConfig` callbacks + as a custom user defined callback. + + Example Usage in `TrainConfig`: + To use `LogHyperparameters`, include it in the `callbacks` list when setting up your + `TrainConfig`: + ```python exec="on" source="material-block" result="json" + from qadence.ml_tools import TrainConfig + from qadence.ml_tools.callbacks import LogHyperparameters + + # Create an instance of the LogHyperparameters callback + log_hyper_callback = LogHyperparameters(on = "val_batch_end", called_every = 100) + + config = TrainConfig( + max_iter=10000, + # Print metrics every 1000 training epochs + print_every=1000, + # Add the custom callback that runs every 100 val_batch_end + callbacks=[log_hyper_callback] + ) + ``` + """ + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Logs hyperparameters using the writer. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + """ + hyperparams = config.hyperparams + writer.log_hyperparams(hyperparams) + + +class SaveCheckpoint(Callback): + """Callback to save a model checkpoint. + + The `SaveCheckpoint` callback can be added to the `TrainConfig` callbacks + as a custom user defined callback. + + Example Usage in `TrainConfig`: + To use `SaveCheckpoint`, include it in the `callbacks` list when setting up your + `TrainConfig`: + ```python exec="on" source="material-block" result="json" + from qadence.ml_tools import TrainConfig + from qadence.ml_tools.callbacks import SaveCheckpoint + + # Create an instance of the SaveCheckpoint callback + save_checkpoint_callback = SaveCheckpoint(on = "val_batch_end", called_every = 100) + + config = TrainConfig( + max_iter=10000, + # Print metrics every 1000 training epochs + print_every=1000, + # Add the custom callback that runs every 100 val_batch_end + callbacks=[save_checkpoint_callback] + ) + ``` + """ + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Saves a model checkpoint. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + """ + folder = config.log_folder + model = trainer.model + optimizer = trainer.optimizer + opt_result = trainer.opt_result + write_checkpoint(folder, model, optimizer, opt_result.iteration) + + +class SaveBestCheckpoint(SaveCheckpoint): + """Callback to save the best model checkpoint based on a validation criterion.""" + + def __init__(self, on: str, called_every: int): + """Initializes the SaveBestCheckpoint callback. + + Args: + on (str): The event to trigger the callback. + called_every (int): Frequency of callback calls in terms of iterations. + """ + super().__init__(on=on, called_every=called_every) + self.best_loss = float("inf") + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Saves the checkpoint if the current loss is better than the best loss. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + """ + opt_result = trainer.opt_result + if config.validation_criterion and config.validation_criterion( + opt_result.loss, self.best_loss, config.val_epsilon + ): + self.best_loss = opt_result.loss + + folder = config.log_folder + model = trainer.model + optimizer = trainer.optimizer + opt_result = trainer.opt_result + write_checkpoint(folder, model, optimizer, "best") + + +class LoadCheckpoint(Callback): + """Callback to load a model checkpoint.""" + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Loads a model checkpoint. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + + Returns: + Any: The result of loading the checkpoint. + """ + folder = config.log_folder + model = trainer.model + optimizer = trainer.optimizer + device = trainer.log_device + return load_checkpoint(folder, model, optimizer, device=device) + + +class LogModelTracker(Callback): + """Callback to log the model using the writer.""" + + def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> Any: + """Logs the model using the writer. + + Args: + trainer (Any): The training object. + config (TrainConfig): The configuration object. + writer (BaseWriter ): The writer object for logging. + """ + model = trainer.model + writer.log_model( + model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader + ) diff --git a/qadence/ml_tools/callbacks/callbackmanager.py b/qadence/ml_tools/callbacks/callbackmanager.py new file mode 100644 index 000000000..a16621b44 --- /dev/null +++ b/qadence/ml_tools/callbacks/callbackmanager.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +import copy +import logging +from typing import Any + +from qadence.ml_tools.callbacks.callback import ( + Callback, + LoadCheckpoint, + LogHyperparameters, + LogModelTracker, + PlotMetrics, + PrintMetrics, + SaveBestCheckpoint, + SaveCheckpoint, + WriteMetrics, +) +from qadence.ml_tools.config import TrainConfig +from qadence.ml_tools.data import OptimizeResult +from qadence.ml_tools.stages import TrainingStage + +from .writer_registry import get_writer + +logger = logging.getLogger("ml_tools") + + +class CallbacksManager: + """Manages and orchestrates the execution of various training callbacks. + + Provides the start training and end training methods. + + Attributes: + use_grad (bool): Indicates whether to use gradients in callbacks. + config (TrainConfig): The training configuration object. + callbacks (List[Callback]): List of callback instances to be executed. + writer (Optional[BaseWriter]): The writer instance for logging metrics and information. + """ + + use_grad: bool = True + + callback_map = { + "PrintMetrics": PrintMetrics, + "WriteMetrics": WriteMetrics, + "PlotMetrics": PlotMetrics, + "SaveCheckpoint": SaveCheckpoint, + "LoadCheckpoint": LoadCheckpoint, + "LogModelTracker": LogModelTracker, + "LogHyperparameters": LogHyperparameters, + "SaveBestCheckpoint": SaveBestCheckpoint, + } + + def __init__(self, config: TrainConfig): + """ + Initializes the CallbacksManager with a training configuration. + + Args: + config (TrainConfig): The training configuration object. + """ + self.config = config + tracking_tool = self.config.tracking_tool + self.writer = get_writer(tracking_tool) + self.callbacks: list[Callback] = [] + + @classmethod + def set_use_grad(cls, use_grad: bool) -> None: + """ + Sets whether gradients should be used in callbacks. + + Args: + use_grad (bool): A boolean indicating whether to use gradients. + """ + if not isinstance(use_grad, bool): + raise ValueError("use_grad must be a boolean value.") + cls.use_grad = use_grad + + def initialize_callbacks(self) -> None: + """Initializes and adds the necessary callbacks based on the configuration.""" + # Train Start + self.callbacks = copy.deepcopy(self.config.callbacks) + self.add_callback("PlotMetrics", "train_start") + if self.config.val_every: + self.add_callback("WriteMetrics", "train_start") + # only save the first checkpoint if not checkpoint_best_only + if not self.config.checkpoint_best_only: + self.add_callback("SaveCheckpoint", "train_start") + + # Checkpointing + if self.config.checkpoint_best_only: + self.add_callback("SaveBestCheckpoint", "val_epoch_end", self.config.val_every) + elif self.config.checkpoint_every: + self.add_callback("SaveCheckpoint", "train_epoch_end", self.config.checkpoint_every) + + # Printing + if self.config.verbose and self.config.print_every: + self.add_callback("PrintMetrics", "train_epoch_end", self.config.print_every) + + # Plotting + if self.config.plot_every: + self.add_callback("PlotMetrics", "train_epoch_end", self.config.plot_every) + + # Writing + if self.config.write_every: + self.add_callback("WriteMetrics", "train_epoch_end", self.config.write_every) + if self.config.val_every: + self.add_callback("WriteMetrics", "val_epoch_end", self.config.val_every) + + # Train End + # Hyperparameters + if self.config.hyperparams: + self.add_callback("LogHyperparameters", "train_end") + # Log model + if self.config.log_model: + self.add_callback("LogModelTracker", "train_end") + if self.config.plot_every: + self.add_callback("PlotMetrics", "train_end") + # only save the last checkpoint if not checkpoint_best_only + if not self.config.checkpoint_best_only: + self.add_callback("SaveCheckpoint", "train_end") + self.add_callback("WriteMetrics", "train_end") + + def add_callback( + self, callback: str | Callback, on: str | TrainingStage, called_every: int = 1 + ) -> None: + """ + Adds a callback to the manager. + + Args: + callback (str | Callback): The callback instance or name. + on (str | TrainingStage): The event on which to trigger the callback. + called_every (int): Frequency of callback calls in terms of iterations. + """ + if isinstance(callback, str): + callback_class = self.callback_map.get(callback) + if callback_class: + # Create an instance of the callback class + callback_instance = callback_class(on=on, called_every=called_every) + self.callbacks.append(callback_instance) + else: + logger.warning(f"Callback '{callback}' not recognized and will be skipped.") + elif isinstance(callback, Callback): + callback.on = on + callback.called_every = called_every + self.callbacks.append(callback) + else: + logger.warning( + f"Invalid callback type: {type(callback)}. Expected str or Callback instance." + ) + + def run_callbacks(self, trainer: Any) -> Any: + """ + Runs callbacks that match the current training state. + + Args: + trainer (Any): The training object managing the training process. + + Returns: + Any: Results of the executed callbacks. + """ + return [ + callback( + when=trainer.training_stage, trainer=trainer, config=self.config, writer=self.writer + ) + for callback in self.callbacks + if callback.on == trainer.training_stage + ] + + def start_training(self, trainer: Any) -> None: + """ + Initializes callbacks and starts the training process. + + Args: + trainer (Any): The training object managing the training process. + """ + # Clear all handlers from the logger + self.initialize_callbacks() + + trainer.opt_result = OptimizeResult(trainer.global_step, trainer.model, trainer.optimizer) + trainer.is_last_iteration = False + + # Load checkpoint only if a new subfolder was NOT recently added + if not trainer.config_manager._added_new_subfolder: + load_checkpoint_callback = LoadCheckpoint(on="train_start", called_every=1) + loaded_result = load_checkpoint_callback.run_callback( + trainer=trainer, + config=self.config, + writer=None, # type: ignore[arg-type] + ) + + if loaded_result: + model, optimizer, init_iter = loaded_result + if isinstance(init_iter, (int, str)): + trainer.model = model + trainer.optimizer = optimizer + trainer.global_step = ( + init_iter if isinstance(init_iter, int) else trainer.global_step + ) + trainer.current_epoch = ( + init_iter if isinstance(init_iter, int) else trainer.current_epoch + ) + trainer.opt_result = OptimizeResult(trainer.current_epoch, model, optimizer) + logger.debug(f"Loaded model and optimizer from {self.config.log_folder}") + + # Setup writer + self.writer.open(self.config, iteration=trainer.global_step) + + def end_training(self, trainer: Any) -> None: + """ + Cleans up and finalizes the training process. + + Args: + trainer (Any): The training object managing the training process. + """ + if self.writer: + self.writer.close() diff --git a/qadence/ml_tools/saveload.py b/qadence/ml_tools/callbacks/saveload.py similarity index 91% rename from qadence/ml_tools/saveload.py rename to qadence/ml_tools/callbacks/saveload.py index a022cd1ec..614f4bad0 100644 --- a/qadence/ml_tools/saveload.py +++ b/qadence/ml_tools/callbacks/saveload.py @@ -11,7 +11,7 @@ from torch.nn import Module from torch.optim import Optimizer -logger = getLogger(__name__) +logger = getLogger("ml_tools") def get_latest_checkpoint_name(folder: Path, type: str, device: str | torch.device = "cpu") -> Path: @@ -19,6 +19,7 @@ def get_latest_checkpoint_name(folder: Path, type: str, device: str | torch.devi files = [f for f in os.listdir(folder) if f.endswith(".pt") and type in f] if len(files) == 0: logger.error(f"Directory {folder} does not contain any {type} checkpoints.") + pass if len(files) == 1: file = Path(files[0]) else: @@ -66,8 +67,7 @@ def write_checkpoint( iteration: int | str, ) -> None: from qadence import QuantumModel - - from .models import QNN + from qadence.ml_tools.models import QNN device = None try: @@ -79,10 +79,8 @@ def write_checkpoint( ) device = str(device).split(":")[0] # in case of using several CUDA devices except Exception as e: - msg = ( - f"Unable to identify in which device the QuantumModel is stored due to {e}." - "Setting device to None" - ) + msg = f"""Unable to identify in which device the QuantumModel is stored due to {e}. + Setting device to None""" logger.warning(msg) iteration_substring = f"{iteration:03n}" if isinstance(iteration, int) else iteration @@ -135,7 +133,9 @@ def load_model( model_ckpt_name = get_latest_checkpoint_name(folder, "model", device) try: - iteration, model_dict = torch.load(folder / model_ckpt_name, *args, **kwargs) + iteration, model_dict = torch.load( + folder / model_ckpt_name, weights_only=False, *args, **kwargs + ) if isinstance(model, (QuantumModel, QNN)): model.load_params_from_dict(model_dict) elif isinstance(model, Module): @@ -146,8 +146,8 @@ def load_model( model.to(device) except Exception as e: - msg = f"Unable to load state dict due to {e}.\ - No corresponding pre-trained model found. Returning the un-trained model." + msg = f"""Unable to load state dict due to {e}. + No corresponding pre-trained model found.""" logger.warning(msg) return model, iteration @@ -162,7 +162,7 @@ def load_optimizer( opt_ckpt_name = get_latest_checkpoint_name(folder, "opt", device) if os.path.isfile(folder / opt_ckpt_name): if isinstance(optimizer, Optimizer): - (_, OptType, optimizer_state) = torch.load(folder / opt_ckpt_name) + (_, OptType, optimizer_state) = torch.load(folder / opt_ckpt_name, weights_only=False) if isinstance(optimizer, OptType): optimizer.load_state_dict(optimizer_state) diff --git a/qadence/ml_tools/callbacks/writer_registry.py b/qadence/ml_tools/callbacks/writer_registry.py new file mode 100644 index 000000000..5bb78a582 --- /dev/null +++ b/qadence/ml_tools/callbacks/writer_registry.py @@ -0,0 +1,430 @@ +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from logging import getLogger +from types import ModuleType +from typing import Any, Callable, Union +from uuid import uuid4 + +import mlflow +from matplotlib.figure import Figure +from mlflow.entities import Run +from mlflow.models import infer_signature +from torch import Tensor +from torch.nn import Module +from torch.utils.data import DataLoader +from torch.utils.tensorboard import SummaryWriter + +from qadence.ml_tools.config import TrainConfig +from qadence.ml_tools.data import OptimizeResult +from qadence.types import ExperimentTrackingTool + +logger = getLogger("ml_tools") + +# Type aliases +PlottingFunction = Callable[[Module, int], tuple[str, Figure]] +InputData = Union[Tensor, dict[str, Tensor]] + + +class BaseWriter(ABC): + """ + Abstract base class for experiment tracking writers. + + Methods: + open(config, iteration=None): Opens the writer and sets up the logging + environment. + close(): Closes the writer and finalizes any ongoing logging processes. + print_metrics(result): Prints metrics and loss in a formatted manner. + write(result): Writes the optimization results to the tracking tool. + log_hyperparams(hyperparams): Logs the hyperparameters to the tracking tool. + plot(model, iteration, plotting_functions): Logs model plots using provided + plotting functions. + log_model(model, dataloader): Logs the model and any relevant information. + """ + + run: Run # [attr-defined] + + @abstractmethod + def open(self, config: TrainConfig, iteration: int | None = None) -> Any: + """ + Opens the writer and prepares it for logging. + + Args: + config: Configuration object containing settings for logging. + iteration (int, optional): The iteration step to start logging from. + Defaults to None. + """ + raise NotImplementedError("Writers must implement an open method.") + + @abstractmethod + def close(self) -> None: + """Closes the writer and finalizes logging.""" + raise NotImplementedError("Writers must implement a close method.") + + @abstractmethod + def write(self, result: OptimizeResult) -> None: + """ + Logs the results of the current iteration. + + Args: + result (OptimizeResult): The optimization results to log. + """ + raise NotImplementedError("Writers must implement a write method.") + + @abstractmethod + def log_hyperparams(self, hyperparams: dict) -> None: + """ + Logs hyperparameters. + + Args: + hyperparams (dict): A dictionary of hyperparameters to log. + """ + raise NotImplementedError("Writers must implement a log_hyperparams method.") + + @abstractmethod + def plot( + self, + model: Module, + iteration: int, + plotting_functions: tuple[PlottingFunction, ...], + ) -> None: + """ + Logs plots of the model using provided plotting functions. + + Args: + model (Module): The model to plot. + iteration (int): The current iteration number. + plotting_functions (tuple[PlottingFunction, ...]): Functions used to + generate plots. + """ + raise NotImplementedError("Writers must implement a plot method.") + + @abstractmethod + def log_model( + self, + model: Module, + train_dataloader: DataLoader | None = None, + val_dataloader: DataLoader | None = None, + test_dataloader: DataLoader | None = None, + ) -> None: + """ + Logs the model and associated data. + + Args: + model (Module): The model to log. + train_dataloader (DataLoader | None): DataLoader for training data. + val_dataloader (DataLoader | None): DataLoader for validation data. + test_dataloader (DataLoader | None): DataLoader for testing data. + """ + raise NotImplementedError("Writers must implement a log_model method.") + + def print_metrics(self, result: OptimizeResult) -> None: + """Prints the metrics and loss in a readable format. + + Args: + result (OptimizeResult): The optimization results to display. + """ + + # Find the key in result.metrics that contains "loss" (case-insensitive) + loss_key = next((k for k in result.metrics if "loss" in k.lower()), None) + if loss_key: + loss_value = result.metrics[loss_key] + msg = f"Iteration {result.iteration: >7} | {loss_key.title()}: {loss_value:.7f} -" + else: + msg = f"Iteration {result.iteration: >7} | Loss: None -" + msg += " ".join([f"{k}: {v:.7f}" for k, v in result.metrics.items() if k != loss_key]) + print(msg) + + +class TensorBoardWriter(BaseWriter): + """Writer for logging to TensorBoard. + + Attributes: + writer (SummaryWriter): The TensorBoard SummaryWriter instance. + """ + + def __init__(self) -> None: + self.writer = None + + def open(self, config: TrainConfig, iteration: int | None = None) -> SummaryWriter: + """ + Opens the TensorBoard writer. + + Args: + config: Configuration object containing settings for logging. + iteration (int, optional): The iteration step to start logging from. + Defaults to None. + + Returns: + SummaryWriter: The initialized TensorBoard writer. + """ + log_dir = str(config.log_folder) + purge_step = iteration if isinstance(iteration, int) else None + self.writer = SummaryWriter(log_dir=log_dir, purge_step=purge_step) + return self.writer + + def close(self) -> None: + """Closes the TensorBoard writer.""" + if self.writer: + self.writer.close() + + def write(self, result: OptimizeResult) -> None: + """ + Logs the results of the current iteration to TensorBoard. + + Args: + result (OptimizeResult): The optimization results to log. + """ + # Not writing loss as loss is available in the metrics + # if result.loss is not None: + # self.writer.add_scalar("loss", float(result.loss), result.iteration) + if self.writer: + for key, value in result.metrics.items(): + self.writer.add_scalar(key, value, result.iteration) + else: + raise RuntimeError( + "The writer is not initialized." + "Please call the 'writer.open()' method before writing" + ) + + def log_hyperparams(self, hyperparams: dict) -> None: + """ + Logs hyperparameters to TensorBoard. + + Args: + hyperparams (dict): A dictionary of hyperparameters to log. + """ + if self.writer: + self.writer.add_hparams(hyperparams, {}) + else: + raise RuntimeError( + "The writer is not initialized." + "Please call the 'writer.open()' method before writing" + ) + + def plot( + self, + model: Module, + iteration: int, + plotting_functions: tuple[PlottingFunction, ...], + ) -> None: + """ + Logs plots of the model using provided plotting functions. + + Args: + model (Module): The model to plot. + iteration (int): The current iteration number. + plotting_functions (tuple[PlottingFunction, ...]): Functions used + to generate plots. + """ + if self.writer: + for pf in plotting_functions: + descr, fig = pf(model, iteration) + self.writer.add_figure(descr, fig, global_step=iteration) + else: + raise RuntimeError( + "The writer is not initialized." + "Please call the 'writer.open()' method before writing" + ) + + def log_model( + self, + model: Module, + train_dataloader: DataLoader | None = None, + val_dataloader: DataLoader | None = None, + test_dataloader: DataLoader | None = None, + ) -> None: + """ + Logs the model. + + Currently not supported by TensorBoard. + + Args: + model (Module): The model to log. + train_dataloader (DataLoader | None): DataLoader for training data. + val_dataloader (DataLoader | None): DataLoader for validation data. + test_dataloader (DataLoader | None): DataLoader for testing data. + """ + logger.warning("Model logging is not supported by tensorboard. No model will be logged.") + + +class MLFlowWriter(BaseWriter): + """ + Writer for logging to MLflow. + + Attributes: + run: The active MLflow run. + mlflow: The MLflow module. + """ + + def __init__(self) -> None: + self.run: Run + self.mlflow: ModuleType + + def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType | None: + """ + Opens the MLflow writer and initializes an MLflow run. + + Args: + config: Configuration object containing settings for logging. + iteration (int, optional): The iteration step to start logging from. + Defaults to None. + + Returns: + mlflow: The MLflow module instance. + """ + self.mlflow = mlflow + tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "") + experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4())) + run_name = os.getenv("MLFLOW_RUN_NAME", str(uuid4())) + + if self.mlflow: + self.mlflow.set_tracking_uri(tracking_uri) + + # Create or get the experiment + exp_filter_string = f"name = '{experiment_name}'" + experiments = self.mlflow.search_experiments(filter_string=exp_filter_string) + if not experiments: + self.mlflow.create_experiment(name=experiment_name) + + self.mlflow.set_experiment(experiment_name) + self.run = self.mlflow.start_run(run_name=run_name, nested=False) + + return self.mlflow + + def close(self) -> None: + """Closes the MLflow run.""" + if self.run: + self.mlflow.end_run() + + def write(self, result: OptimizeResult) -> None: + """ + Logs the results of the current iteration to MLflow. + + Args: + result (OptimizeResult): The optimization results to log. + """ + # Not writing loss as loss is available in the metrics + # if result.loss is not None: + # self.mlflow.log_metric("loss", float(result.loss), step=result.iteration) + if self.mlflow: + self.mlflow.log_metrics(result.metrics, step=result.iteration) + else: + raise RuntimeError( + "The writer is not initialized." + "Please call the 'writer.open()' method before writing" + ) + + def log_hyperparams(self, hyperparams: dict) -> None: + """ + Logs hyperparameters to MLflow. + + Args: + hyperparams (dict): A dictionary of hyperparameters to log. + """ + if self.mlflow: + self.mlflow.log_params(hyperparams) + else: + raise RuntimeError( + "The writer is not initialized." + "Please call the 'writer.open()' method before writing" + ) + + def plot( + self, + model: Module, + iteration: int, + plotting_functions: tuple[PlottingFunction, ...], + ) -> None: + """ + Logs plots of the model using provided plotting functions. + + Args: + model (Module): The model to plot. + iteration (int): The current iteration number. + plotting_functions (tuple[PlottingFunction, ...]): Functions used + to generate plots. + """ + if self.mlflow: + for pf in plotting_functions: + descr, fig = pf(model, iteration) + self.mlflow.log_figure(fig, descr) + else: + raise RuntimeError( + "The writer is not initialized." + "Please call the 'writer.open()' method before writing" + ) + + def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader | None) -> Any: + """ + Infers the signature of the model based on the input data from the dataloader. + + Args: + model (Module): The model to use for inference. + dataloader (DataLoader | None): DataLoader for model inputs. + + Returns: + Optional[Any]: The inferred signature, if available. + """ + if dataloader is None: + return None + + xs: InputData + xs, *_ = next(iter(dataloader)) + preds = model(xs) + + if isinstance(xs, Tensor): + xs = xs.detach().cpu().numpy() + preds = preds.detach().cpu().numpy() + return infer_signature(xs, preds) + + return None + + def log_model( + self, + model: Module, + train_dataloader: DataLoader | None = None, + val_dataloader: DataLoader | None = None, + test_dataloader: DataLoader | None = None, + ) -> None: + """ + Logs the model and its signature to MLflow using the provided data loaders. + + Args: + model (Module): The model to log. + train_dataloader (DataLoader | None): DataLoader for training data. + val_dataloader (DataLoader | None): DataLoader for validation data. + test_dataloader (DataLoader | None): DataLoader for testing data. + """ + if not self.mlflow: + raise RuntimeError( + "The writer is not initialized." + "Please call the 'writer.open()' method before writing" + ) + + signatures = self.get_signature_from_dataloader(model, train_dataloader) + self.mlflow.pytorch.log_model(model, artifact_path="model", signature=signatures) + + +# Writer registry +WRITER_REGISTRY = { + ExperimentTrackingTool.TENSORBOARD: TensorBoardWriter, + ExperimentTrackingTool.MLFLOW: MLFlowWriter, +} + + +def get_writer(tracking_tool: ExperimentTrackingTool) -> BaseWriter: + """Factory method to get the appropriate writer based on the tracking tool. + + Args: + tracking_tool (ExperimentTrackingTool): The experiment tracking tool to use. + + Returns: + BaseWriter: An instance of the appropriate writer. + """ + writer_class = WRITER_REGISTRY.get(tracking_tool) + if writer_class: + return writer_class() + else: + raise ValueError(f"Unsupported tracking tool: {tracking_tool}") diff --git a/qadence/ml_tools/config.py b/qadence/ml_tools/config.py index 27a68281c..c33649753 100644 --- a/qadence/ml_tools/config.py +++ b/qadence/ml_tools/config.py @@ -1,19 +1,14 @@ from __future__ import annotations -import datetime -import os from dataclasses import dataclass, field, fields from logging import getLogger from pathlib import Path -from typing import Any, Callable, Type -from uuid import uuid4 +from typing import Callable, Type from sympy import Basic -from torch import Tensor from qadence.blocks.analog import AnalogBlock from qadence.blocks.primitive import ParametricBlock -from qadence.ml_tools.data import OptimizeResult from qadence.operations import RX, AnalogRX from qadence.parameters import Parameter from qadence.types import ( @@ -28,306 +23,185 @@ logger = getLogger(__file__) -CallbackFunction = Callable[[OptimizeResult], None] -CallbackConditionFunction = Callable[[OptimizeResult], bool] - - -class Callback: - """Callback functions are calling in train functions. - - Each callback function should take at least as first input - an OptimizeResult instance. - - Note: when setting call_after_opt to True, we skip - verifying iteration % called_every == 0. - - Attributes: - callback (CallbackFunction): Callback function accepting an - OptimizeResult as first argument. - callback_condition (CallbackConditionFunction | None, optional): Function that - conditions the call to callback. Defaults to None. - modify_optimize_result (CallbackFunction | dict[str, Any] | None, optional): - Function that modify the OptimizeResult before callback. - For instance, one can change the `extra` (dict) argument to be used in callback. - If a dict is provided, the `extra` field of OptimizeResult is updated with the dict. - called_every (int, optional): Callback to be called each `called_every` epoch. - Defaults to 1. - If callback_condition is None, we set - callback_condition to returns True when iteration % called_every == 0. - call_before_opt (bool, optional): If true, callback is applied before training. - Defaults to False. - call_end_epoch (bool, optional): If true, callback is applied during training, - after an epoch is performed. Defaults to True. - call_after_opt (bool, optional): If true, callback is applied after training. - Defaults to False. - call_during_eval (bool, optional): If true, callback is applied during evaluation. - Defaults to False. - """ - - def __init__( - self, - callback: CallbackFunction, - callback_condition: CallbackConditionFunction | None = None, - modify_optimize_result: CallbackFunction | dict[str, Any] | None = None, - called_every: int = 1, - call_before_opt: bool = False, - call_end_epoch: bool = True, - call_after_opt: bool = False, - call_during_eval: bool = False, - ) -> None: - """Initialized Callback. - - Args: - callback (CallbackFunction): Callback function accepting an - OptimizeResult as ifrst argument. - callback_condition (CallbackConditionFunction | None, optional): Function that - conditions the call to callback. Defaults to None. - modify_optimize_result (CallbackFunction | dict[str, Any] | None , optional): - Function that modify the OptimizeResult before callback. If a dict - is provided, this updates the `extra` field of OptimizeResult. - called_every (int, optional): Callback to be called each `called_every` epoch. - Defaults to 1. - If callback_condition is None, we set - callback_condition to returns True when iteration % called_every == 0. - call_before_opt (bool, optional): If true, callback is applied before training. - Defaults to False. - call_end_epoch (bool, optional): If true, callback is applied during training, - after an epoch is performed. Defaults to True. - call_after_opt (bool, optional): If true, callback is applied after training. - Defaults to False. - call_during_eval (bool, optional): If true, callback is applied during evaluation. - Defaults to False. - """ - self.callback = callback - self.call_before_opt = call_before_opt - self.call_end_epoch = call_end_epoch - self.call_after_opt = call_after_opt - self.call_during_eval = call_during_eval - - if called_every <= 0: - raise ValueError("Please provide a strictly positive `called_every` argument.") - self.called_every = called_every - - if callback_condition is None: - self.callback_condition = lambda opt_result: True - else: - self.callback_condition = callback_condition - - if modify_optimize_result is None: - self.modify_optimize_result = lambda opt_result: opt_result - elif isinstance(modify_optimize_result, dict): - - def update_extra(opt_result: OptimizeResult) -> OptimizeResult: - opt_result.extra.update(modify_optimize_result) - return opt_result - - self.modify_optimize_result = update_extra - else: - self.modify_optimize_result = modify_optimize_result - - def __call__(self, opt_result: OptimizeResult, is_last_iteration: bool = False) -> Any: - """Apply callback if conditions are met. - - Note that the current result may be modified by specifying a function - `modify_optimize_result` for instance to add inputs to the `extra` argument - of the current OptimizeResult. - - Args: - opt_result (OptimizeResult): Current result. - is_last_iteration (bool, optional): When True, - avoid verifying modulo. Defaults to False. - Useful when call_after_opt is True. - - Returns: - Any: The result of the callback. - """ - opt_result = self.modify_optimize_result(opt_result) - if opt_result.iteration % self.called_every == 0 and self.callback_condition(opt_result): - return self.callback(opt_result) - if is_last_iteration and self.callback_condition(opt_result): - return self.callback(opt_result) - - -def run_callbacks( - callback_iterable: list[Callback], opt_res: OptimizeResult, is_last_iteration: bool = False -) -> None: - """Run a list of Callback given the current OptimizeResult. - - Used in train functions. - - Args: - callback_iterable (list[Callback]): Iterable of Callbacks - opt_res (OptimizeResult): Current optimization result, - is_last_iteration (bool, optional): Whether we reached the last iteration or not. - Defaults to False. - """ - for callback in callback_iterable: - callback(opt_res, is_last_iteration) - @dataclass class TrainConfig: - """Default config for the train function. + """Default configuration for the training process. - The default value of - each field can be customized with the constructor: + This class provides default settings for various aspects of the training loop, + such as logging, checkpointing, and validation. The default values for these + fields can be customized when an instance of `TrainConfig` is created. + Example: ```python exec="on" source="material-block" result="json" from qadence.ml_tools import TrainConfig - c = TrainConfig(folder="/tmp/train") + c = TrainConfig(root_folder="/tmp/train") print(str(c)) # markdown-exec: hide ``` """ max_iter: int = 10000 - """Number of training iterations.""" - print_every: int = 1000 - """Print loss/metrics. + """Number of training iterations (epochs) to perform. + + This defines the total number + of times the model will be updated. - Set to 0 to disable + In case of InfiniteTensorDataset, each epoch will have 1 batch. + In case of TensorDataset, each epoch will have len(dataloader) batches. """ - write_every: int = 50 - """Write loss and metrics with the tracking tool. - Set to 0 to disable + print_every: int = 0 + """Frequency (in epochs) for printing loss and metrics to the console during training. + + Set to 0 to disable this output, meaning that metrics and loss will not be printed + during training. """ - checkpoint_every: int = 5000 - """Write model/optimizer checkpoint. - Set to 0 to disable + write_every: int = 0 + """Frequency (in epochs) for writing loss and metrics using the tracking tool during training. + + Set to 0 to disable this logging, which prevents metrics from being logged to the tracking tool. + Note that the metrics will always be written at the end of training regardless of this setting. """ - plot_every: int = 5000 - """Write figures. - Set to 0 to disable + checkpoint_every: int = 0 + """Frequency (in epochs) for saving model and optimizer checkpoints during training. + + Set to 0 to disable checkpointing. This helps in resuming training or recovering + models. + Note that setting checkpoint_best_only = True will disable this and only best checkpoints will + be saved. + """ + + plot_every: int = 0 + """Frequency (in epochs) for generating and saving figures during training. + + Set to 0 to disable plotting. """ - callbacks: list[Callback] = field(default_factory=lambda: list()) - """List of callbacks.""" + + callbacks: list = field(default_factory=lambda: list()) + """List of callbacks to execute during training. + + Callbacks can be used for + custom behaviors, such as early stopping, custom logging, or other actions + triggered at specific events. + """ + log_model: bool = False - """Logs a serialised version of the model.""" - folder: Path | None = None - """Checkpoint/tensorboard logs folder.""" + """Whether to log a serialized version of the model. + + When set to `True`, the + model's state will be logged, useful for model versioning and reproducibility. + """ + + root_folder: Path = Path("./qml_logs") + """The root folder for saving checkpoints and tensorboard logs. + + The default path is "./qml_logs" + + This can be set to a specific directory where training artifacts are to be stored. + Checkpoints will be saved inside a subfolder in this directory. Subfolders will be + created based on `create_subfolder_per_run` argument. + """ + create_subfolder_per_run: bool = False - """Checkpoint/tensorboard logs stored in subfolder with name `_`. + """Whether to create a subfolder for each run, named `__`. + + This ensures logs and checkpoints from different runs do not overwrite each other, + which is helpful for rapid prototyping. If `False`, training will resume from + the latest checkpoint if one exists in the specified log folder. + """ + + log_folder: Path = Path("./") + """The log folder for saving checkpoints and tensorboard logs. - Prevents continuing from previous checkpoint, useful for fast prototyping. + This stores the path where all logs and checkpoints are being saved + for this training session. `log_folder` takes precedence over `root_folder` and + `create_subfolder_per_run` arguments. If the user specifies a log_folder, + all checkpoints will be saved in this folder and `root_folder` argument + will not be used. """ + checkpoint_best_only: bool = False - """Write model/optimizer checkpoint only if a metric has improved.""" - val_every: int | None = None - """Calculate validation metric. + """If `True`, checkpoints are only saved if there is an improvement in the. - If None, validation check is not performed. + validation metric. This conserves storage by only keeping the best models. + + validation_criterion is required when this is set to True. """ + + val_every: int = 0 + """Frequency (in epochs) for performing validation. + + If set to 0, validation is not performed. + Note that metrics from validation are always written, regardless of the `write_every` setting. + Note that initial validation happens at the start of training (when val_every > 0) + For initial validation - initial metrics are written. + - checkpoint is saved (when checkpoint_best_only = False) + """ + val_epsilon: float = 1e-5 - """Safety margin to check if validation loss is smaller than the lowest. + """A small safety margin used to compare the current validation loss with the. - validation loss across previous iterations. + best previous validation loss. This is used to determine improvements in metrics. """ + validation_criterion: Callable | None = None - """A boolean function which evaluates a given validation metric is satisfied.""" + """A function to evaluate whether a given validation metric meets a desired condition. + + The validation_criterion has the following format: + def validation_criterion(val_loss: float, best_val_loss: float, val_epsilon: float) -> bool: + # process + + If `None`, no custom validation criterion is applied. + """ + trainstop_criterion: Callable | None = None - """A boolean function which evaluates a given training stopping metric is satisfied.""" - batch_size: int = 1 - """The batch_size to use when passing a list/tuple of torch.Tensors.""" - verbose: bool = True - """Whether or not to print out metrics values during training.""" - tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD - """The tracking tool of choice.""" - hyperparams: dict = field(default_factory=dict) - """Hyperparameters to track.""" - plotting_functions: tuple[LoggablePlotFunction, ...] = field(default_factory=tuple) # type: ignore - """Functions for in-train plotting.""" - - # tensorboard only allows for certain types as hyperparameters - _tb_allowed_hyperparams_types: tuple = field( - default=(int, float, str, bool, Tensor), init=False, repr=False - ) - - def _filter_tb_hyperparams(self) -> None: - keys_to_remove = [ - key - for key, value in self.hyperparams.items() - if not isinstance(value, TrainConfig._tb_allowed_hyperparams_types) - ] - if keys_to_remove: - logger.warning( - f"Tensorboard cannot log the following hyperparameters: {keys_to_remove}." - ) - for key in keys_to_remove: - self.hyperparams.pop(key) + """A function to determine if the training process should stop based on a. - def __post_init__(self) -> None: - if self.folder: - if isinstance(self.folder, str): # type: ignore [unreachable] - self.folder = Path(self.folder) # type: ignore [unreachable] - if self.create_subfolder_per_run: - subfoldername = ( - datetime.datetime.now().strftime("%Y%m%dT%H%M%S") + "_" + hex(os.getpid())[2:] - ) - self.folder = self.folder / subfoldername - if self.trainstop_criterion is None: - self.trainstop_criterion = lambda x: x <= self.max_iter - if self.validation_criterion is None: - self.validation_criterion = lambda *x: False - if self.hyperparams and self.tracking_tool == ExperimentTrackingTool.TENSORBOARD: - self._filter_tb_hyperparams() - if self.tracking_tool == ExperimentTrackingTool.MLFLOW: - self._mlflow_config = MLFlowConfig() - if self.plotting_functions and self.tracking_tool != ExperimentTrackingTool.MLFLOW: - logger.warning("In-training plots are only available with mlflow tracking.") - if not self.plotting_functions and self.tracking_tool == ExperimentTrackingTool.MLFLOW: - logger.warning("Tracking with mlflow, but no plotting functions provided.") - - @property - def mlflow_config(self) -> MLFlowConfig: - if self.tracking_tool == ExperimentTrackingTool.MLFLOW: - return self._mlflow_config - else: - raise AttributeError( - "mlflow_config is available only for with the mlflow tracking tool." - ) + specific stopping metric. If `None`, training continues until `max_iter` is reached. + """ + batch_size: int = 1 + """The batch size to use when processing a list or tuple of torch.Tensors. -class MLFlowConfig: + This specifies how many samples are processed in each training iteration. """ - Configuration for mlflow tracking. - Example: + verbose: bool = True + """Whether to print metrics and status messages during training. - export MLFLOW_TRACKING_URI=tracking_uri - export MLFLOW_EXPERIMENT=experiment_name - export MLFLOW_RUN_NAME=run_name + If `True`, detailed metrics and status updates will be displayed in the console. """ - def __init__(self) -> None: - import mlflow + tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD + """The tool used for tracking training progress and logging metrics. - self.tracking_uri: str = os.getenv("MLFLOW_TRACKING_URI", "") - """The URI of the mlflow tracking server. + Options include tools like TensorBoard, which help visualize and monitor + model training. + """ - An empty string, or a local file path, prefixed with file:/. - Data is stored locally at the provided file (or ./mlruns if empty). - """ + hyperparams: dict = field(default_factory=dict) + """A dictionary of hyperparameters to be tracked. - self.experiment_name: str = os.getenv("MLFLOW_EXPERIMENT", str(uuid4())) - """The name of the experiment. + This can include learning rates, + regularization parameters, or any other training-related configurations. + """ - If None or empty, a new experiment is created with a random UUID. - """ + plotting_functions: tuple[LoggablePlotFunction, ...] = field(default_factory=tuple) # type: ignore + """Functions used for in-training plotting. - self.run_name: str = os.getenv("MLFLOW_RUN_NAME", str(uuid4())) - """The name of the run.""" + These are called to generate + plots that are logged or saved at specified intervals. + """ - mlflow.set_tracking_uri(self.tracking_uri) + _subfolders: list = field(default_factory=list) + """List of subfolders used for logging different runs using the same config inside the. - # activate existing or create experiment - exp_filter_string = f"name = '{self.experiment_name}'" - if not mlflow.search_experiments(filter_string=exp_filter_string): - mlflow.create_experiment(name=self.experiment_name) + root folder. - self.experiment = mlflow.set_experiment(self.experiment_name) - self.run = mlflow.start_run(run_name=self.run_name, nested=False) + Each subfolder is of structure `__`. + """ @dataclass diff --git a/qadence/ml_tools/data.py b/qadence/ml_tools/data.py index 8e92e7f94..877b8f27f 100644 --- a/qadence/ml_tools/data.py +++ b/qadence/ml_tools/data.py @@ -1,8 +1,8 @@ from __future__ import annotations +import random from dataclasses import dataclass, field from functools import singledispatch -from itertools import cycle from typing import Any, Iterator from nevergrad.optimization.base import Optimizer as NGOptimizer @@ -72,13 +72,17 @@ def __init__(self, *tensors: Tensor): ``` """ self.tensors = tensors + self.indices = list(range(self.tensors[0].size(0))) def __iter__(self) -> Iterator: if len(set([t.size(0) for t in self.tensors])) != 1: raise ValueError("Size of first dimension must be the same for all tensors.") - for idx in cycle(range(self.tensors[0].size(0))): - yield tuple(t[idx] for t in self.tensors) + # Shuffle the indices for every full pass + random.shuffle(self.indices) + while True: + for idx in self.indices: + yield tuple(t[idx] for t in self.tensors) def to_dataloader(*tensors: Tensor, batch_size: int = 1, infinite: bool = False) -> DataLoader: diff --git a/qadence/ml_tools/loss/__init__.py b/qadence/ml_tools/loss/__init__.py new file mode 100644 index 000000000..cc8bbfc16 --- /dev/null +++ b/qadence/ml_tools/loss/__init__.py @@ -0,0 +1,10 @@ +from __future__ import annotations + +from .loss import cross_entropy_loss, get_loss_fn, mse_loss + +# Modules to be automatically added to the qadence.ml_tools.loss namespace +__all__ = [ + "cross_entropy_loss", + "get_loss_fn", + "mse_loss", +] diff --git a/qadence/ml_tools/loss/loss.py b/qadence/ml_tools/loss/loss.py new file mode 100644 index 000000000..d1bff72c2 --- /dev/null +++ b/qadence/ml_tools/loss/loss.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +from typing import Callable + +import torch +import torch.nn as nn + + +def mse_loss( + model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor] +) -> tuple[torch.Tensor, dict[str, float]]: + """Computes the Mean Squared Error (MSE) loss between model predictions and targets. + + Args: + model (nn.Module): The PyTorch model used for generating predictions. + batch (Tuple[torch.Tensor, torch.Tensor]): A tuple containing: + - inputs (torch.Tensor): The input data. + - targets (torch.Tensor): The ground truth labels. + + Returns: + Tuple[torch.Tensor, dict[str, float]]: + - loss (torch.Tensor): The computed MSE loss value. + - metrics (dict[str, float]): A dictionary with the MSE loss value. + """ + criterion = nn.MSELoss() + inputs, targets = batch + outputs = model(inputs) + loss = criterion(outputs, targets) + + metrics = {"mse": loss} + return loss, metrics + + +def cross_entropy_loss( + model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor] +) -> tuple[torch.Tensor, dict[str, float]]: + """Computes the Cross Entropy loss between model predictions and targets. + + Args: + model (nn.Module): The PyTorch model used for generating predictions. + batch (Tuple[torch.Tensor, torch.Tensor]): A tuple containing: + - inputs (torch.Tensor): The input data. + - targets (torch.Tensor): The ground truth labels. + + Returns: + Tuple[torch.Tensor, dict[str, float]]: + - loss (torch.Tensor): The computed Cross Entropy loss value. + - metrics (dict[str, float]): A dictionary with the Cross Entropy loss value. + """ + criterion = nn.CrossEntropyLoss() + inputs, targets = batch + outputs = model(inputs) + loss = criterion(outputs, targets) + + metrics = {"cross_entropy": loss} + return loss, metrics + + +def get_loss_fn(loss_fn: str | Callable | None) -> Callable: + """ + Returns the appropriate loss function based on the input argument. + + Args: + loss_fn (str | Callable | None): The loss function to use. + - If `loss_fn` is a callable, it will be returned directly. + - If `loss_fn` is a string, it should be one of: + - "mse": Returns the `mse_loss` function. + - "cross_entropy": Returns the `cross_entropy_loss` function. + - If `loss_fn` is `None`, the default `mse_loss` function will be returned. + + Returns: + Callable: The corresponding loss function. + + Raises: + ValueError: If `loss_fn` is a string but not a supported loss function name. + """ + if callable(loss_fn): + return loss_fn + elif isinstance(loss_fn, str): + if loss_fn == "mse": + return mse_loss + elif loss_fn == "cross_entropy": + return cross_entropy_loss + else: + raise ValueError(f"Unsupported loss function: {loss_fn}") + else: + return mse_loss diff --git a/qadence/ml_tools/optimize_step.py b/qadence/ml_tools/optimize_step.py index 93bd9ac5c..b59675e1a 100644 --- a/qadence/ml_tools/optimize_step.py +++ b/qadence/ml_tools/optimize_step.py @@ -2,11 +2,14 @@ from typing import Any, Callable +import nevergrad as ng import torch from torch.nn import Module from torch.optim import Optimizer from qadence.ml_tools.data import data_to_device +from qadence.ml_tools.parameters import set_parameters +from qadence.ml_tools.tensors import promote_to_tensor def optimize_step( @@ -19,21 +22,21 @@ def optimize_step( ) -> tuple[torch.Tensor | float, dict | None]: """Default Torch optimize step with closure. - This is the default optimization step which should work for most - of the standard use cases of optimization of Torch models + This is the default optimization step. Args: - model (Module): The input model - optimizer (Optimizer): The chosen Torch optimizer + model (Module): The input model to be optimized. + optimizer (Optimizer): The chosen Torch optimizer. loss_fn (Callable): A custom loss function - xs (dict | list | torch.Tensor | None): the input data. If None it means - that the given model does not require any input data - device (torch.device): A target device to run computation on. - dtype (torch.dtype): Data type for xs conversion. + that returns the loss value and a dictionary of metrics. + xs (dict | list | Tensor | None): The input data. If None, it means + the given model does not require any input data. + device (torch.device): A target device to run computations on. + dtype (torch.dtype): Data type for `xs` conversion. Returns: - tuple: tuple containing the computed loss value, and a dictionary with - the collected metrics. + tuple[Tensor | float, dict | None]: A tuple containing the computed loss value + and a dictionary with collected metrics. """ loss, metrics = None, {} @@ -52,3 +55,35 @@ def closure() -> Any: optimizer.step(closure) # return the loss/metrics that are being mutated inside the closure... return loss, metrics + + +def update_ng_parameters( + model: Module, + optimizer: ng.optimizers.Optimizer, + loss_fn: Callable[[Module, torch.Tensor | None], tuple[float, dict]], + data: torch.Tensor | None, + ng_params: ng.p.Array, +) -> tuple[float, dict, ng.p.Array]: + """Update the model parameters using Nevergrad. + + This function integrates Nevergrad for derivative-free optimization. + + Args: + model (Module): The PyTorch model to be optimized. + optimizer (ng.optimizers.Optimizer): A Nevergrad optimizer instance. + loss_fn (Callable[[Module, Tensor | None], tuple[float, dict]]): A custom loss function + that returns the loss value and a dictionary of metrics. + data (Tensor | None): Input data for the model. If None, it means the model does + not require input data. + ng_params (ng.p.Array): The current set of parameters managed by Nevergrad. + + Returns: + tuple[float, dict, ng.p.Array]: A tuple containing the computed loss value, + a dictionary of metrics, and the updated Nevergrad parameters. + """ + loss, metrics = loss_fn(model, data) # type: ignore[misc] + optimizer.tell(ng_params, float(loss)) + ng_params = optimizer.ask() # type: ignore[assignment] + params = promote_to_tensor(ng_params.value, requires_grad=False) + set_parameters(model, params) + return loss, metrics, ng_params diff --git a/qadence/ml_tools/printing.py b/qadence/ml_tools/printing.py deleted file mode 100644 index 89bfb1f14..000000000 --- a/qadence/ml_tools/printing.py +++ /dev/null @@ -1,154 +0,0 @@ -from __future__ import annotations - -from logging import getLogger -from typing import Any, Callable, Union - -from matplotlib.figure import Figure -from torch import Tensor -from torch.nn import Module -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter - -from qadence.ml_tools.data import DictDataLoader -from qadence.types import ExperimentTrackingTool - -logger = getLogger(__name__) - -PlottingFunction = Callable[[Module, int], tuple[str, Figure]] -InputData = Union[Tensor, dict[str, Tensor]] - - -def print_metrics(loss: float | None, metrics: dict, iteration: int) -> None: - msg = " ".join( - [f"Iteration {iteration: >7} | Loss: {loss:.7f} -"] - + [f"{k}: {v.item():.7f}" for k, v in metrics.items()] - ) - print(msg) - - -def write_tensorboard( - writer: SummaryWriter, loss: float = None, metrics: dict | None = None, iteration: int = 0 -) -> None: - metrics = metrics or dict() - if loss is not None: - writer.add_scalar("loss", loss, iteration) - for key, arg in metrics.items(): - writer.add_scalar(key, arg, iteration) - - -def log_hyperparams_tensorboard(writer: SummaryWriter, hyperparams: dict, metrics: dict) -> None: - writer.add_hparams(hyperparams, metrics) - - -def plot_tensorboard( - writer: SummaryWriter, - model: Module, - iteration: int, - plotting_functions: tuple[PlottingFunction], -) -> None: - for pf in plotting_functions: - descr, fig = pf(model, iteration) - writer.add_figure(descr, fig, global_step=iteration) - - -def log_model_tensorboard( - writer: SummaryWriter, - model: Module, - dataloader: Union[None, DataLoader, DictDataLoader], -) -> None: - logger.warning("Model logging is not supported by tensorboard. No model will be logged.") - - -def write_mlflow(writer: Any, loss: float | None, metrics: dict, iteration: int) -> None: - writer.log_metrics({"loss": float(loss)}, step=iteration) # type: ignore - writer.log_metrics(metrics, step=iteration) # logs the single metrics - - -def log_hyperparams_mlflow(writer: Any, hyperparams: dict, metrics: dict) -> None: - writer.log_params(hyperparams) # type: ignore - - -def plot_mlflow( - writer: Any, - model: Module, - iteration: int, - plotting_functions: tuple[PlottingFunction], -) -> None: - for pf in plotting_functions: - descr, fig = pf(model, iteration) - writer.log_figure(fig, descr) - - -def log_model_mlflow( - writer: Any, model: Module, dataloader: DataLoader | DictDataLoader | None -) -> None: - signature = None - if dataloader is not None: - xs: InputData - xs, *_ = next(iter(dataloader)) - preds = model(xs) - if isinstance(xs, Tensor): - xs = xs.numpy() - preds = preds.detach().numpy() - elif isinstance(xs, dict): - for key, val in xs.items(): - xs[key] = val.numpy() - for key, val in preds.items(): - preds[key] = val.detach.numpy() - - try: - from mlflow.models import infer_signature - - signature = infer_signature(xs, preds) - except ImportError: - logger.warning( - "An MLFlow specific function has been called but MLFlow failed to import." - "Please install MLFlow or adjust your code." - ) - - writer.pytorch.log_model(model, artifact_path="model", signature=signature) - - -TRACKER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = { - ExperimentTrackingTool.TENSORBOARD: write_tensorboard, - ExperimentTrackingTool.MLFLOW: write_mlflow, -} - -LOGGER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = { - ExperimentTrackingTool.TENSORBOARD: log_hyperparams_tensorboard, - ExperimentTrackingTool.MLFLOW: log_hyperparams_mlflow, -} - -PLOTTER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = { - ExperimentTrackingTool.TENSORBOARD: plot_tensorboard, - ExperimentTrackingTool.MLFLOW: plot_mlflow, -} - -MODEL_LOGGER_MAPPING: dict[ExperimentTrackingTool, Callable[..., None]] = { - ExperimentTrackingTool.TENSORBOARD: log_model_tensorboard, - ExperimentTrackingTool.MLFLOW: log_model_mlflow, -} - - -def write_tracker( - *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD -) -> None: - return TRACKER_MAPPING[tracking_tool](*args) - - -def log_tracker( - *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD -) -> None: - return LOGGER_MAPPING[tracking_tool](*args) - - -def plot_tracker( - *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD -) -> None: - return PLOTTER_MAPPING[tracking_tool](*args) - - -def log_model_tracker( - *args: Any, tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD -) -> None: - return MODEL_LOGGER_MAPPING[tracking_tool](*args) diff --git a/qadence/ml_tools/stages.py b/qadence/ml_tools/stages.py new file mode 100644 index 000000000..aff30fad0 --- /dev/null +++ b/qadence/ml_tools/stages.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from qadence.types import StrEnum + + +class TrainingStage(StrEnum): + """Different stages in the training, validation, and testing process.""" + + IDLE = "idle" + """An 'idle' stage for scenarios where no training, validation, or testing is involved.""" + + TRAIN_START = "train_start" + """Marks the start of the training process.""" + + TRAIN_END = "train_end" + """Marks the end of the training process.""" + + TRAIN_EPOCH_START = "train_epoch_start" + """Indicates the start of a training epoch.""" + + TRAIN_EPOCH_END = "train_epoch_end" + """Indicates the end of a training epoch.""" + + TRAIN_BATCH_START = "train_batch_start" + """Marks the start of processing a training batch.""" + + TRAIN_BATCH_END = "train_batch_end" + """Marks the end of processing a training batch.""" + + VAL_EPOCH_START = "val_epoch_start" + """Indicates the start of a validation epoch.""" + + VAL_EPOCH_END = "val_epoch_end" + """Indicates the end of a validation epoch.""" + + VAL_BATCH_START = "val_batch_start" + """Marks the start of processing a validation batch.""" + + VAL_BATCH_END = "val_batch_end" + """Marks the end of processing a validation batch.""" + + TEST_BATCH_START = "test_batch_start" + """Marks the start of processing a test batch.""" + + TEST_BATCH_END = "test_batch_end" + """Marks the end of processing a test batch.""" diff --git a/qadence/ml_tools/train_grad.py b/qadence/ml_tools/train_grad.py deleted file mode 100644 index b94b04d7e..000000000 --- a/qadence/ml_tools/train_grad.py +++ /dev/null @@ -1,395 +0,0 @@ -from __future__ import annotations - -import importlib -import math -from logging import getLogger -from typing import Any, Callable, Union - -from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn -from torch import Tensor, complex128, float32, float64 -from torch import device as torch_device -from torch import dtype as torch_dtype -from torch.nn import DataParallel, Module -from torch.optim import Optimizer -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter - -from qadence.ml_tools.config import Callback, TrainConfig, run_callbacks -from qadence.ml_tools.data import DictDataLoader, OptimizeResult, data_to_device -from qadence.ml_tools.optimize_step import optimize_step -from qadence.ml_tools.printing import ( - log_model_tracker, - log_tracker, - plot_tracker, - print_metrics, - write_tracker, -) -from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint -from qadence.types import ExperimentTrackingTool - -logger = getLogger(__name__) - - -def train( - model: Module, - dataloader: Union[None, DataLoader, DictDataLoader], - optimizer: Optimizer, - config: TrainConfig, - loss_fn: Callable, - device: torch_device = None, - optimize_step: Callable = optimize_step, - dtype: torch_dtype = None, -) -> tuple[Module, Optimizer]: - """Runs the training loop with gradient-based optimizer. - - Assumes that `loss_fn` returns a tuple of (loss, - metrics: dict), where `metrics` is a dict of scalars. Loss and metrics are - written to tensorboard. Checkpoints are written every - `config.checkpoint_every` steps (and after the last training step). If a - checkpoint is found at `config.folder` we resume training from there. The - tensorboard logs can be viewed via `tensorboard --logdir /path/to/folder`. - - Args: - model: The model to train. - dataloader: dataloader of different types. If None, no data is required by - the model - optimizer: The optimizer to use. - config: `TrainConfig` with additional training options. - loss_fn: Loss function returning (loss: float, metrics: dict[str, float], ...) - device: String defining device to train on, pass 'cuda' for GPU. - optimize_step: Customizable optimization callback which is called at every iteration.= - The function must have the signature `optimize_step(model, - optimizer, loss_fn, xs, device="cpu")`. - dtype: The dtype to use for the data. - - Example: - ```python exec="on" source="material-block" - from pathlib import Path - import torch - from itertools import count - from qadence import Parameter, QuantumCircuit, Z - from qadence import hamiltonian_factory, hea, feature_map, chain - from qadence import QNN - from qadence.ml_tools import TrainConfig, train_with_grad, to_dataloader - - n_qubits = 2 - fm = feature_map(n_qubits) - ansatz = hea(n_qubits=n_qubits, depth=3) - observable = hamiltonian_factory(n_qubits, detuning = Z) - circuit = QuantumCircuit(n_qubits, fm, ansatz) - - model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad") - batch_size = 1 - input_values = {"phi": torch.rand(batch_size, requires_grad=True)} - pred = model(input_values) - - ## lets prepare the train routine - - cnt = count() - criterion = torch.nn.MSELoss() - optimizer = torch.optim.Adam(model.parameters(), lr=0.1) - - def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: - next(cnt) - x, y = data[0], data[1] - out = model(x) - loss = criterion(out, y) - return loss, {} - - tmp_path = Path("/tmp") - n_epochs = 5 - batch_size = 25 - config = TrainConfig( - folder=tmp_path, - max_iter=n_epochs, - checkpoint_every=100, - write_every=100, - ) - x = torch.linspace(0, 1, batch_size).reshape(-1, 1) - y = torch.sin(x) - data = to_dataloader(x, y, batch_size=batch_size, infinite=True) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) - ``` - """ - # load available checkpoint - init_iter = 0 - log_device = "cpu" if device is None else device - if config.folder: - model, optimizer, init_iter = load_checkpoint( - config.folder, model, optimizer, device=log_device - ) - logger.debug(f"Loaded model and optimizer from {config.folder}") - - # Move model to device before optimizer is loaded - if isinstance(model, DataParallel): - model = model.module.to(device=device, dtype=dtype) - else: - model = model.to(device=device, dtype=dtype) - # initialize tracking tool - if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD: - writer = SummaryWriter(config.folder, purge_step=init_iter) - else: - writer = importlib.import_module("mlflow") - - perform_val = isinstance(config.val_every, int) - if perform_val: - if not isinstance(dataloader, DictDataLoader): - raise ValueError( - "If `config.val_every` is provided as an integer, dataloader must" - "be an instance of `DictDataLoader`." - ) - iter_keys = dataloader.dataloaders.keys() - if "train" not in iter_keys or "val" not in iter_keys: - raise ValueError( - "If `config.val_every` is provided as an integer, the dictdataloader" - "must have `train` and `val` keys to access the respective dataloaders." - ) - val_dataloader = dataloader.dataloaders["val"] - dataloader = dataloader.dataloaders["train"] - - ## Training - progress = Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - TimeRemainingColumn(elapsed_when_finished=True), - ) - data_dtype = None - if dtype: - data_dtype = float64 if dtype == complex128 else float32 - - best_val_loss = math.inf - - if not ((dataloader is None) or isinstance(dataloader, (DictDataLoader, DataLoader))): - raise NotImplementedError( - f"Unsupported dataloader type: {type(dataloader)}. " - "You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader." - ) - - def next_loss_iter(dl_iter: Union[None, DataLoader, DictDataLoader]) -> Any: - """Get loss on the next batch of a dataloader. - - loaded on device if not None. - - Args: - dl_iter (Union[None, DataLoader, DictDataLoader]): Dataloader. - - Returns: - Any: Loss value - """ - xs = next(dl_iter) if dl_iter is not None else None - xs_to_device = data_to_device(xs, device=device, dtype=data_dtype) - return loss_fn(model, xs_to_device) - - # populate callbacks with already available internal functions - # printing, writing and plotting - callbacks = config.callbacks - - # printing - if config.verbose and config.print_every > 0: - # Note that the loss returned by optimize_step - # is the value before doing the training step - # which is printed accordingly by the previous iteration number - callbacks += [ - Callback( - lambda opt_res: print_metrics(opt_res.loss, opt_res.metrics, opt_res.iteration - 1), - called_every=config.print_every, - ) - ] - - # plotting - callbacks += [ - Callback( - lambda opt_res: plot_tracker( - writer, - opt_res.model, - opt_res.iteration, - config.plotting_functions, - tracking_tool=config.tracking_tool, - ), - called_every=config.plot_every, - call_before_opt=True, - ) - ] - - # writing metrics - # we specify two writers, - # to write at evaluation time and before evaluation - callbacks += [ - Callback( - lambda opt_res: write_tracker( - writer, - opt_res.loss, - opt_res.metrics, - opt_res.iteration - 1, # loss returned be optimized_step is at -1 - tracking_tool=config.tracking_tool, - ), - called_every=config.write_every, - call_end_epoch=True, - ), - Callback( - lambda opt_res: write_tracker( - writer, - opt_res.loss, - opt_res.metrics, - opt_res.iteration, # after_opt we match the right loss function - tracking_tool=config.tracking_tool, - ), - called_every=config.write_every, - call_end_epoch=False, - call_after_opt=True, - ), - ] - if perform_val: - callbacks += [ - Callback( - lambda opt_res: write_tracker( - writer, - None, - opt_res.metrics, - opt_res.iteration, - tracking_tool=config.tracking_tool, - ), - called_every=config.write_every, - call_before_opt=True, - call_during_eval=True, - ) - ] - - # checkpointing - if config.folder and config.checkpoint_every > 0 and not config.checkpoint_best_only: - callbacks += [ - Callback( - lambda opt_res: write_checkpoint( - config.folder, # type: ignore[arg-type] - opt_res.model, - opt_res.optimizer, - opt_res.iteration, - ), - called_every=config.checkpoint_every, - call_before_opt=False, - call_after_opt=True, - ) - ] - - if config.folder and config.checkpoint_best_only: - callbacks += [ - Callback( - lambda opt_res: write_checkpoint( - config.folder, # type: ignore[arg-type] - opt_res.model, - opt_res.optimizer, - "best", - ), - called_every=config.checkpoint_every, - call_before_opt=True, - call_after_opt=True, - call_during_eval=True, - ) - ] - - callbacks_before_opt = [ - callback - for callback in callbacks - if callback.call_before_opt and not callback.call_during_eval - ] - callbacks_before_opt_eval = [ - callback for callback in callbacks if callback.call_before_opt and callback.call_during_eval - ] - - with progress: - dl_iter = iter(dataloader) if dataloader is not None else None - - # Initial validation evaluation - try: - opt_result = OptimizeResult(init_iter, model, optimizer) - if perform_val: - dl_iter_val = iter(val_dataloader) if val_dataloader is not None else None - best_val_loss, metrics, *_ = next_loss_iter(dl_iter_val) - metrics["val_loss"] = best_val_loss - opt_result.metrics = metrics - run_callbacks(callbacks_before_opt_eval, opt_result) - - run_callbacks(callbacks_before_opt, opt_result) - - except KeyboardInterrupt: - logger.info("Terminating training gracefully after the current iteration.") - - # outer epoch loop - init_iter += 1 - callbacks_end_epoch = [ - callback - for callback in callbacks - if callback.call_end_epoch and not callback.call_during_eval - ] - callbacks_end_epoch_eval = [ - callback - for callback in callbacks - if callback.call_end_epoch and callback.call_during_eval - ] - for iteration in progress.track(range(init_iter, init_iter + config.max_iter)): - try: - # in case there is not data needed by the model - # this is the case, for example, of quantum models - # which do not have classical input data (e.g. chemistry) - loss, metrics = optimize_step( - model=model, - optimizer=optimizer, - loss_fn=loss_fn, - xs=None if dataloader is None else next(dl_iter), # type: ignore[arg-type] - device=device, - dtype=data_dtype, - ) - if isinstance(loss, Tensor): - loss = loss.item() - opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics) - run_callbacks(callbacks_end_epoch, opt_result) - - if perform_val: - if iteration % config.val_every == 0: - val_loss, *_ = next_loss_iter(dl_iter_val) - if config.validation_criterion(val_loss, best_val_loss, config.val_epsilon): # type: ignore[misc] - best_val_loss = val_loss - metrics["val_loss"] = val_loss - opt_result.metrics = metrics - - run_callbacks(callbacks_end_epoch_eval, opt_result) - - except KeyboardInterrupt: - logger.info("Terminating training gracefully after the current iteration.") - break - - # For handling printing/writing the last training loss - # as optimize_step does not give the loss value at the last iteration - try: - loss, metrics, *_ = next_loss_iter(dl_iter) - if isinstance(loss, Tensor): - loss = loss.item() - if perform_val: - # reputting val_loss as already evaluated before - metrics["val_loss"] = val_loss - print_metrics(loss, metrics, iteration) - - except KeyboardInterrupt: - logger.info("Terminating training gracefully after the current iteration.") - - # Final callbacks, by default checkpointing and writing - opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics) - callbacks_after_opt = [callback for callback in callbacks if callback.call_after_opt] - run_callbacks(callbacks_after_opt, opt_result, is_last_iteration=True) - - # writing hyperparameters - if config.hyperparams: - log_tracker(writer, config.hyperparams, metrics, tracking_tool=config.tracking_tool) - - # logging the model - if config.log_model: - log_model_tracker(writer, model, dataloader, tracking_tool=config.tracking_tool) - - # close tracker - if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD: - writer.close() - elif config.tracking_tool == ExperimentTrackingTool.MLFLOW: - writer.end_run() - - return model, optimizer diff --git a/qadence/ml_tools/train_no_grad.py b/qadence/ml_tools/train_no_grad.py deleted file mode 100644 index af1f2255f..000000000 --- a/qadence/ml_tools/train_no_grad.py +++ /dev/null @@ -1,199 +0,0 @@ -from __future__ import annotations - -import importlib -from logging import getLogger -from typing import Callable - -import nevergrad as ng -from nevergrad.optimization.base import Optimizer as NGOptimizer -from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn -from torch import Tensor -from torch.nn import Module -from torch.utils.data import DataLoader -from torch.utils.tensorboard import SummaryWriter - -from qadence.ml_tools.config import Callback, TrainConfig, run_callbacks -from qadence.ml_tools.data import DictDataLoader, OptimizeResult -from qadence.ml_tools.parameters import get_parameters, set_parameters -from qadence.ml_tools.printing import ( - log_model_tracker, - log_tracker, - plot_tracker, - print_metrics, - write_tracker, -) -from qadence.ml_tools.saveload import load_checkpoint, write_checkpoint -from qadence.ml_tools.tensors import promote_to_tensor -from qadence.types import ExperimentTrackingTool - -logger = getLogger(__name__) - - -def train( - model: Module, - dataloader: DictDataLoader | DataLoader | None, - optimizer: NGOptimizer, - config: TrainConfig, - loss_fn: Callable, -) -> tuple[Module, NGOptimizer]: - """Runs the training loop with a gradient-free optimizer. - - Assumes that `loss_fn` returns a tuple of (loss, metrics: dict), where - `metrics` is a dict of scalars. Loss and metrics are written to - tensorboard. Checkpoints are written every `config.checkpoint_every` steps - (and after the last training step). If a checkpoint is found at `config.folder` - we resume training from there. The tensorboard logs can be viewed via - `tensorboard --logdir /path/to/folder`. - - Args: - model: The model to train - dataloader: Dataloader constructed via `dictdataloader` - optimizer: The optimizer to use taken from the Nevergrad library. If this is not - the case the function will raise an AssertionError - config: `TrainConfig` with additional training options. - loss_fn: Loss function returning (loss: float, metrics: dict[str, float]) - """ - init_iter = 0 - if config.folder: - model, optimizer, init_iter = load_checkpoint(config.folder, model, optimizer) - logger.debug(f"Loaded model and optimizer from {config.folder}") - - def _update_parameters( - data: Tensor | None, ng_params: ng.p.Array - ) -> tuple[float, dict, ng.p.Array]: - loss, metrics = loss_fn(model, data) # type: ignore[misc] - optimizer.tell(ng_params, float(loss)) - ng_params = optimizer.ask() # type: ignore [assignment] - params = promote_to_tensor(ng_params.value, requires_grad=False) - set_parameters(model, params) - return loss, metrics, ng_params - - assert loss_fn is not None, "Provide a valid loss function" - # TODO: support also Scipy optimizers - assert isinstance(optimizer, NGOptimizer), "Use only optimizers from the Nevergrad library" - - # initialize tracking tool - if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD: - writer = SummaryWriter(config.folder, purge_step=init_iter) - else: - writer = importlib.import_module("mlflow") - - # set optimizer configuration and initial parameters - optimizer.budget = config.max_iter - optimizer.enable_pickling() - - # TODO: Make it GPU compatible if possible - params = get_parameters(model).detach().numpy() - ng_params = ng.p.Array(init=params) - - if not ((dataloader is None) or isinstance(dataloader, (DictDataLoader, DataLoader))): - raise NotImplementedError( - f"Unsupported dataloader type: {type(dataloader)}. " - "You can use e.g. `qadence.ml_tools.to_dataloader` to build a dataloader." - ) - - # serial training - # TODO: Add a parallelization using the num_workers argument in Nevergrad - progress = Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - TimeRemainingColumn(elapsed_when_finished=True), - ) - - # populate callbacks with already available internal functions - # printing, writing and plotting - callbacks = config.callbacks - - # printing - if config.verbose and config.print_every > 0: - callbacks += [ - Callback( - lambda opt_res: print_metrics(opt_res.loss, opt_res.metrics, opt_res.iteration), - called_every=config.print_every, - ) - ] - - # writing metrics - if config.write_every > 0: - callbacks += [ - Callback( - lambda opt_res: write_tracker( - writer, - opt_res.loss, - opt_res.metrics, - opt_res.iteration, - tracking_tool=config.tracking_tool, - ), - called_every=config.write_every, - call_after_opt=True, - ) - ] - - # plot tracker - if config.plot_every > 0: - callbacks += [ - Callback( - lambda opt_res: plot_tracker( - writer, - opt_res.model, - opt_res.iteration, - config.plotting_functions, - tracking_tool=config.tracking_tool, - ), - called_every=config.plot_every, - ) - ] - - # checkpointing - if config.folder and config.checkpoint_every > 0: - callbacks += [ - Callback( - lambda opt_res: write_checkpoint( - config.folder, # type: ignore[arg-type] - opt_res.model, - opt_res.optimizer, - opt_res.iteration, - ), - called_every=config.checkpoint_every, - call_after_opt=True, - ) - ] - - callbacks_end_opt = [ - callback - for callback in callbacks - if callback.call_end_epoch and not callback.call_during_eval - ] - - with progress: - dl_iter = iter(dataloader) if dataloader is not None else None - - for iteration in progress.track(range(init_iter, init_iter + config.max_iter)): - loss, metrics, ng_params = _update_parameters( - None if dataloader is None else next(dl_iter), ng_params # type: ignore[arg-type] - ) - opt_result = OptimizeResult(iteration, model, optimizer, loss, metrics) - run_callbacks(callbacks_end_opt, opt_result) - - if iteration >= init_iter + config.max_iter: - break - - # writing hyperparameters - if config.hyperparams: - log_tracker(writer, config.hyperparams, metrics, tracking_tool=config.tracking_tool) - - if config.log_model: - log_model_tracker(writer, model, dataloader, tracking_tool=config.tracking_tool) - - # Final callbacks - callbacks_after_opt = [callback for callback in callbacks if callback.call_after_opt] - run_callbacks(callbacks_after_opt, opt_result, is_last_iteration=True) - - # close tracker - if config.tracking_tool == ExperimentTrackingTool.TENSORBOARD: - writer.close() - elif config.tracking_tool == ExperimentTrackingTool.MLFLOW: - writer.end_run() - - return model, optimizer diff --git a/qadence/ml_tools/train_utils/__init__.py b/qadence/ml_tools/train_utils/__init__.py new file mode 100644 index 000000000..3069dcb81 --- /dev/null +++ b/qadence/ml_tools/train_utils/__init__.py @@ -0,0 +1,7 @@ +from __future__ import annotations + +from .base_trainer import BaseTrainer +from .config_manager import ConfigManager + +# Modules to be automatically added to the qadence.ml_tools.loss namespace +__all__ = ["BaseTrainer", "ConfigManager"] diff --git a/qadence/ml_tools/train_utils/base_trainer.py b/qadence/ml_tools/train_utils/base_trainer.py new file mode 100644 index 000000000..f1089805c --- /dev/null +++ b/qadence/ml_tools/train_utils/base_trainer.py @@ -0,0 +1,548 @@ +from __future__ import annotations + +from contextlib import contextmanager +from logging import getLogger +from typing import Any, Callable, Iterator + +import nevergrad as ng +import torch +from nevergrad.optimization.base import Optimizer as NGOptimizer +from torch import nn, optim +from torch.utils.data import DataLoader + +from qadence.ml_tools.callbacks import CallbacksManager +from qadence.ml_tools.config import TrainConfig +from qadence.ml_tools.data import InfiniteTensorDataset +from qadence.ml_tools.loss import get_loss_fn +from qadence.ml_tools.optimize_step import optimize_step +from qadence.ml_tools.parameters import get_parameters +from qadence.ml_tools.stages import TrainingStage + +from .config_manager import ConfigManager + +logger = getLogger("ml_tools") + + +class BaseTrainer: + """Base class for training machine learning models using a given optimizer. + + The base class implements contextmanager for gradient based/free optimization, + properties, property setters, input validations, callback decorator generator, + and empty hooks for different training steps. + + This class provides: + - Context managers for enabling/disabling gradient-based optimization + - Properties for managing models, optimizers, and dataloaders + - Input validations and a callback decorator generator + - Config and callback managers using the provided `TrainConfig` + + Attributes: + use_grad (bool): Indicates if gradients are used for optimization. Default is True. + + model (nn.Module): The neural network model. + optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training. + config (TrainConfig): The configuration settings for training. + train_dataloader (DataLoader | None): DataLoader for training data. + val_dataloader (DataLoader | None): DataLoader for validation data. + test_dataloader (DataLoader | None): DataLoader for testing data. + + optimize_step (Callable): Function for performing an optimization step. + loss_fn (Callable | str ]): loss function to use. Default loss function + used is 'mse' + + num_training_batches (int): Number of training batches. In case of + InfiniteTensorDataset only 1 batch per epoch is used. + num_validation_batches (int): Number of validation batches. In case of + InfiniteTensorDataset only 1 batch per epoch is used. + num_test_batches (int): Number of test batches. In case of + InfiniteTensorDataset only 1 batch per epoch is used. + + state (str): Current state in the training process + """ + + _use_grad: bool = True + + def __init__( + self, + model: nn.Module, + optimizer: optim.Optimizer | NGOptimizer | None, + config: TrainConfig, + loss_fn: str | Callable = "mse", + optimize_step: Callable = optimize_step, + train_dataloader: DataLoader | None = None, + val_dataloader: DataLoader | None = None, + test_dataloader: DataLoader | None = None, + max_batches: int | None = None, + ): + """ + Initializes the BaseTrainer. + + Args: + model (nn.Module): The model to train. + optimizer (optim.Optimizer | NGOptimizer | None): The optimizer + for training. + config (TrainConfig): The TrainConfig settings for training. + loss_fn (str | Callable): The loss function to use. + str input to be specified to use a default loss function. + currently supported loss functions: 'mse', 'cross_entropy'. + If not specified, default mse loss will be used. + train_dataloader (DataLoader | None): DataLoader for training data. + If the model does not need data to evaluate loss, no dataset + should be provided. + val_dataloader (DataLoader | None): DataLoader for validation data. + test_dataloader (DataLoader | None): DataLoader for testing data. + max_batches (int | None): Maximum number of batches to process per epoch. + This is only valid in case of finite TensorDataset dataloaders. + if max_batches is not None, the maximum number of batches used will + be min(max_batches, len(dataloader.dataset)) + In case of InfiniteTensorDataset only 1 batch per epoch is used. + """ + self._model: nn.Module + self._optimizer: optim.Optimizer | NGOptimizer | None + self._config: TrainConfig + self._train_dataloader: DataLoader | None = None + self._val_dataloader: DataLoader | None = None + self._test_dataloader: DataLoader | None = None + + self.config = config + self.model = model + self.optimizer = optimizer + self.max_batches = max_batches + + self.num_training_batches: int + self.num_validation_batches: int + self.num_test_batches: int + + self.train_dataloader = train_dataloader + self.val_dataloader = val_dataloader + self.test_dataloader = test_dataloader + + self.loss_fn: Callable = get_loss_fn(loss_fn) + self.optimize_step: Callable = optimize_step + self.ng_params: ng.p.Array + self.training_stage: TrainingStage = TrainingStage("idle") + + @property + def use_grad(self) -> bool: + """ + Returns the optimization framework for the trainer. + + use_grad = True : Gradient based optimization + use_grad = False : Gradient free optimization + + Returns: + bool: Bool value for using gradient. + """ + return self._use_grad + + @use_grad.setter + def use_grad(self, use_grad: bool) -> None: + """ + Returns the optimization framework for the trainer. + + use_grad = True : Gradient based optimization + use_grad = False : Gradient free optimization + + Returns: + bool: Bool value for using gradient. + """ + if not isinstance(use_grad, bool): + raise TypeError("use_grad must be an True or False.") + self._use_grad = use_grad + + @classmethod + def set_use_grad(cls, value: bool) -> None: + """ + Sets the global use_grad flag. + + Args: + value (bool): Whether to use gradient-based optimization. + """ + if not isinstance(value, bool): + raise TypeError("use_grad must be a boolean value.") + cls._use_grad = value + + @property + def model(self) -> nn.Module: + """ + Returns the model if set, otherwise raises an error. + + Returns: + nn.Module: The model. + """ + if self._model is None: + raise ValueError("Model has not been set.") + return self._model + + @model.setter + def model(self, model: nn.Module) -> None: + """ + Sets the model, ensuring it is an instance of nn.Module. + + Args: + model (nn.Module): The neural network model. + """ + if model is not None and not isinstance(model, nn.Module): + raise TypeError("model must be an instance of nn.Module or None.") + self._model = model + + @property + def optimizer(self) -> optim.Optimizer | NGOptimizer | None: + """ + Returns the optimizer if set, otherwise raises an error. + + Returns: + optim.Optimizer | NGOptimizer | None: The optimizer. + """ + return self._optimizer + + @optimizer.setter + def optimizer(self, optimizer: optim.Optimizer | NGOptimizer | None) -> None: + """ + Sets the optimizer, checking compatibility with gradient use. + + We also set up the budget/behavior of different optimizers here. + + Args: + optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training. + """ + if optimizer is not None: + if self.use_grad: + if not isinstance(optimizer, optim.Optimizer): + raise TypeError("use_grad=True requires a PyTorch optimizer instance.") + else: + if not isinstance(optimizer, NGOptimizer): + raise TypeError("use_grad=False requires a Nevergrad optimizer instance.") + else: + optimizer.budget = self.config.max_iter + optimizer.enable_pickling() + params = get_parameters(self.model).detach().numpy() + self.ng_params = ng.p.Array(init=params) + + self._optimizer = optimizer + + @property + def train_dataloader(self) -> DataLoader: + """ + Returns the training DataLoader, validating its type. + + Returns: + DataLoader: The DataLoader for training data. + """ + return self._train_dataloader + + @train_dataloader.setter + def train_dataloader(self, dataloader: DataLoader) -> None: + """ + Sets the training DataLoader and computes the number of batches. + + Args: + dataloader (DataLoader): The DataLoader for training data. + """ + self._validate_dataloader(dataloader, "train") + self._train_dataloader = dataloader + self.num_training_batches = self._compute_num_batches(dataloader) + + @property + def val_dataloader(self) -> DataLoader: + """ + Returns the validation DataLoader, validating its type. + + Returns: + DataLoader: The DataLoader for validation data. + """ + return self._val_dataloader + + @val_dataloader.setter + def val_dataloader(self, dataloader: DataLoader) -> None: + """ + Sets the validation DataLoader and computes the number of batches. + + Args: + dataloader (DataLoader): The DataLoader for validation data. + """ + self._validate_dataloader(dataloader, "val") + self._val_dataloader = dataloader + self.num_validation_batches = self._compute_num_batches(dataloader) + + @property + def test_dataloader(self) -> DataLoader: + """ + Returns the test DataLoader, validating its type. + + Returns: + DataLoader: The DataLoader for testing data. + """ + return self._test_dataloader + + @test_dataloader.setter + def test_dataloader(self, dataloader: DataLoader) -> None: + """ + Sets the test DataLoader and computes the number of batches. + + Args: + dataloader (DataLoader): The DataLoader for testing data. + """ + self._validate_dataloader(dataloader, "test") + self._test_dataloader = dataloader + self.num_test_batches = self._compute_num_batches(dataloader) + + @property + def config(self) -> TrainConfig: + """ + Returns the training configuration. + + Returns: + TrainConfig: The configuration object. + """ + return self._config + + @config.setter + def config(self, value: TrainConfig) -> None: + """ + Sets the training configuration and initializes callback and config managers. + + Args: + value (TrainConfig): The configuration object. + """ + if value and not isinstance(value, TrainConfig): + raise TypeError("config must be an instance of TrainConfig.") + self._config = value + self.callback_manager = CallbacksManager(value) + self.config_manager = ConfigManager(value) + + def _compute_num_batches(self, dataloader: DataLoader) -> int: + """ + Computes the number of batches for the given DataLoader. + + Args: + dataloader (DataLoader): The DataLoader for which to compute + the number of batches. + """ + if dataloader is None: + return 1 + dataset = dataloader.dataset + if isinstance(dataset, InfiniteTensorDataset): + return 1 + else: + n_batches = int( + (dataset.tensors[0].size(0) + dataloader.batch_size - 1) // dataloader.batch_size + ) + return min(self.max_batches, n_batches) if self.max_batches is not None else n_batches + + def _validate_dataloader(self, dataloader: DataLoader, dataloader_type: str) -> None: + """ + Validates the type of the DataLoader and raises errors for unsupported types. + + Args: + dataloader (DataLoader): The DataLoader to validate. + dataloader_type (str): The type of DataLoader ("train", "val", or "test"). + """ + if dataloader is not None: + if not isinstance(dataloader, DataLoader): + raise NotImplementedError( + f"Unsupported dataloader type: {type(dataloader)}." + "The dataloader must be an instance of DataLoader." + ) + if dataloader_type == "val" and self.config.val_every > 0: + if not isinstance(dataloader, DataLoader): + raise ValueError( + "If `config.val_every` is provided as an integer > 0, validation_dataloader" + "must be an instance of `DataLoader`." + ) + + @staticmethod + def callback(phase: str) -> Callable: + """ + Decorator for executing callbacks before and after a phase. + + Phase are different hooks during the training. list of valid + phases is defined in Callbacks. + We also update the current state of the training process in + the callback decorator. + + Args: + phase (str): The phase for which the callback is executed (e.g., "train", + "train_epoch", "train_batch"). + + Returns: + Callable: The decorated function. + """ + + def decorator(method: Callable) -> Callable: + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + start_event = f"{phase}_start" + end_event = f"{phase}_end" + + self.training_stage = TrainingStage(start_event) + self.callback_manager.run_callbacks(trainer=self) + result = method(self, *args, **kwargs) + + self.training_stage = TrainingStage(end_event) + # build_optimize_result method is defined in the trainer. + self.build_optimize_result(result) + self.callback_manager.run_callbacks(trainer=self) + + return result + + return wrapper + + return decorator + + @contextmanager + def enable_grad_opt(self, optimizer: optim.Optimizer | None = None) -> Iterator[None]: + """ + Context manager to temporarily enable gradient-based optimization. + + Args: + optimizer (optim.Optimizer): The PyTorch optimizer to use. + If no optimizer is provided, default optimizer for trainer + object will be used. + """ + original_mode = self.use_grad + original_optimizer = self._optimizer + try: + self.use_grad = True + self.callback_manager.use_grad = True + self.optimizer = optimizer if optimizer else self.optimizer + yield + finally: + self.use_grad = original_mode + self.callback_manager.use_grad = original_mode + self.optimizer = original_optimizer + + @contextmanager + def disable_grad_opt(self, optimizer: NGOptimizer | None = None) -> Iterator[None]: + """ + Context manager to temporarily disable gradient-based optimization. + + Args: + optimizer (NGOptimizer): The Nevergrad optimizer to use. + If no optimizer is provided, default optimizer for trainer + object will be used. + """ + original_mode = self.use_grad + original_optimizer = self._optimizer + try: + self.use_grad = False + self.callback_manager.use_grad = False + self.optimizer = optimizer if optimizer else self.optimizer + yield + finally: + self.use_grad = original_mode + self.callback_manager.use_grad = original_mode + self.optimizer = original_optimizer + + def on_train_start(self) -> None: + """Called at the start of training.""" + pass + + def on_train_end( + self, + train_losses: list[list[tuple[torch.Tensor, Any]]], + val_losses: list[list[tuple[torch.Tensor, Any]]] | None = None, + ) -> None: + """ + Called at the end of training. + + Args: + train_losses (list[list[tuple[torch.Tensor, Any]]]): + Metrics for the training losses. + list -> list -> tuples + Epochs -> Training Batches -> (loss, metrics) + val_losses (list[list[tuple[torch.Tensor, Any]]] | None): + Metrics for the validation losses. + list -> list -> tuples + Epochs -> Validation Batches -> (loss, metrics) + """ + pass + + def on_train_epoch_start(self) -> None: + """Called at the start of each training epoch.""" + pass + + def on_train_epoch_end(self, train_epoch_loss_metrics: list[tuple[torch.Tensor, Any]]) -> None: + """ + Called at the end of each training epoch. + + Args: + train_epoch_loss_metrics: Metrics for the training epoch losses. + list -> tuples + Training Batches -> (loss, metrics) + """ + pass + + def on_val_epoch_start(self) -> None: + """Called at the start of each validation epoch.""" + pass + + def on_val_epoch_end(self, val_epoch_loss_metrics: list[tuple[torch.Tensor, Any]]) -> None: + """ + Called at the end of each validation epoch. + + Args: + val_epoch_loss_metrics: Metrics for the validation epoch loss. + list -> tuples + Validation Batches -> (loss, metrics) + """ + pass + + def on_train_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None: + """ + Called at the start of each training batch. + + Args: + batch: A batch of data from the DataLoader. Typically a tuple containing + input tensors and corresponding target tensors. + """ + pass + + def on_train_batch_end(self, train_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None: + """ + Called at the end of each training batch. + + Args: + train_batch_loss_metrics: Metrics for the training batch loss. + tuple of (loss, metrics) + """ + pass + + def on_val_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None: + """ + Called at the start of each validation batch. + + Args: + batch: A batch of data from the DataLoader. Typically a tuple containing + input tensors and corresponding target tensors. + """ + pass + + def on_val_batch_end(self, val_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None: + """ + Called at the end of each validation batch. + + Args: + val_batch_loss_metrics: Metrics for the validation batch loss. + tuple of (loss, metrics) + """ + pass + + def on_test_batch_start(self, batch: tuple[torch.Tensor, ...] | None) -> None: + """ + Called at the start of each testing batch. + + Args: + batch: A batch of data from the DataLoader. Typically a tuple containing + input tensors and corresponding target tensors. + """ + pass + + def on_test_batch_end(self, test_batch_loss_metrics: tuple[torch.Tensor, Any]) -> None: + """ + Called at the end of each testing batch. + + Args: + test_batch_loss_metrics: Metrics for the testing batch loss. + tuple of (loss, metrics) + """ + pass diff --git a/qadence/ml_tools/train_utils/config_manager.py b/qadence/ml_tools/train_utils/config_manager.py new file mode 100644 index 000000000..084e86c4c --- /dev/null +++ b/qadence/ml_tools/train_utils/config_manager.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import datetime +import os +from logging import getLogger +from pathlib import Path + +from torch import Tensor + +from qadence.ml_tools.config import TrainConfig +from qadence.types import ExperimentTrackingTool + +logger = getLogger("ml_tools") + + +class ConfigManager: + """A class to manage and initialize the configuration for a. + + machine learning training run using TrainConfig. + + Attributes: + config (TrainConfig): The training configuration object + containing parameters and settings. + """ + + optimization_type: str = "with_grad" + + def __init__(self, config: TrainConfig): + """ + Initialize the ConfigManager with a given training configuration. + + Args: + config (TrainConfig): The training configuration object. + """ + self.config: TrainConfig = config + + def initialize_config(self) -> None: + """ + Initialize the configuration by setting up the folder structure,. + + handling hyperparameters, deriving additional parameters, + and logging warnings. + """ + self._initialize_folder() + self._handle_hyperparams() + self._setup_additional_configuration() + self._log_warnings() + + def _initialize_folder(self) -> None: + """ + Initialize the folder structure for logging. + + Creates a log folder + if the folder path is specified in the configuration. + config has three parameters + - folder: The root folder for logging + - subfolders: list of subfolders inside `folder` that are used for logging + - log_folder: folder currently used for logging. + """ + self.config.log_folder = self._createlog_folder(self.config.root_folder) + + def _createlog_folder(self, root_folder: str | Path) -> Path: + """ + Create a log folder in the specified root folder, adding subfolders if required. + + Args: + root_folder (str | Path): The root folder where the log folder will be created. + + Returns: + Path: The path to the created log folder. + """ + self._added_new_subfolder: bool = False + root_folder_path = Path(root_folder) + root_folder_path.mkdir(parents=True, exist_ok=True) + + if self.config.create_subfolder_per_run: + self._add_subfolder() + log_folder = root_folder_path / self.config._subfolders[-1] + else: + if self.config._subfolders: + if self.config.log_folder == root_folder_path / self.config._subfolders[-1]: + log_folder = root_folder_path / self.config._subfolders[-1] + else: + log_folder = Path(self.config.log_folder) + else: + if self.config.log_folder == Path("./"): + self._add_subfolder() + log_folder = root_folder_path / self.config._subfolders[-1] + else: + log_folder = Path(self.config.log_folder) + + log_folder.mkdir(parents=True, exist_ok=True) + return Path(log_folder) + + def _add_subfolder(self) -> None: + """ + Add a unique subfolder name to the configuration for logging. + + The subfolder name includes a run ID, timestamp, and process ID in hexadecimal format. + """ + timestamp = datetime.datetime.now().strftime("%Y%m%dT%H%M%S") + pid_hex = hex(os.getpid())[2:] + run_id = len(self.config._subfolders) + 1 + subfolder_name = f"{run_id}_{timestamp}_{pid_hex}" + self.config._subfolders.append(str(subfolder_name)) + self._added_new_subfolder = True + + def _handle_hyperparams(self) -> None: + """ + Handle and filter hyperparameters based on the selected tracking tool. + + Removes incompatible hyperparameters when using TensorBoard. + """ + # tensorboard only allows for certain types as hyperparameters + + if ( + self.config.hyperparams + and self.config.tracking_tool == ExperimentTrackingTool.TENSORBOARD + ): + self._filter_tb_hyperparams() + + def _filter_tb_hyperparams(self) -> None: + """ + Filter out hyperparameters that cannot be logged by TensorBoard. + + Logs a warning for the removed hyperparameters. + """ + + # tensorboard only allows for certain types as hyperparameters + tb_allowed_hyperparams_types: tuple = (int, float, str, bool, Tensor) + keys_to_remove = [ + key + for key, value in self.config.hyperparams.items() + if not isinstance(value, tb_allowed_hyperparams_types) + ] + if keys_to_remove: + logger.warning( + f"Tensorboard cannot log the following hyperparameters: {keys_to_remove}." + ) + for key in keys_to_remove: + self.config.hyperparams.pop(key) + + def _setup_additional_configuration(self) -> None: + """ + Derive additional parameters for the training configuration. + + Sets the stopping criterion if it is not already defined. + """ + if self.config.trainstop_criterion is None: + self.config.trainstop_criterion = lambda x: x <= self.config.max_iter + + def _log_warnings(self) -> None: + """ + Log warnings for incompatible configurations related to tracking tools. + + and plotting functions. + """ + if ( + self.config.plotting_functions + and self.config.tracking_tool != ExperimentTrackingTool.MLFLOW + ): + logger.warning("In-training plots are only available with mlflow tracking.") + if ( + not self.config.plotting_functions + and self.config.tracking_tool == ExperimentTrackingTool.MLFLOW + ): + logger.warning("Tracking with mlflow, but no plotting functions provided.") + if self.config.plot_every and not self.config.plotting_functions: + logger.warning( + "`plot_every` is only available when `plotting_functions` are provided." + "No plots will be saved." + ) + if self.config.checkpoint_best_only and not self.config.validation_criterion: + logger.warning( + "`Checkpoint_best_only` is only available when `validation_criterion` is provided." + "No checkpoints will be saved." + ) + if self.config.log_folder != Path("./") and self.config.root_folder != Path("./qml_logs"): + logger.warning("Both `log_folder` and `root_folder` provided by the user.") + if self.config.log_folder != Path("./") and self.config.create_subfolder_per_run: + logger.warning( + "`log_folder` is invalid when `create_subfolder_per_run` = True." + "`root_folder` (default qml_logs) will be used to save logs." + ) diff --git a/qadence/ml_tools/trainer.py b/qadence/ml_tools/trainer.py new file mode 100644 index 000000000..d848d53d9 --- /dev/null +++ b/qadence/ml_tools/trainer.py @@ -0,0 +1,692 @@ +from __future__ import annotations + +import copy +from itertools import islice +from logging import getLogger +from typing import Any, Callable, Iterable, cast + +import torch +from nevergrad.optimization.base import Optimizer as NGOptimizer +from rich.progress import BarColumn, Progress, TaskProgressColumn, TextColumn, TimeRemainingColumn +from torch import complex128, float32, float64, nn, optim +from torch import device as torch_device +from torch import dtype as torch_dtype +from torch.utils.data import DataLoader + +from qadence.ml_tools.config import TrainConfig +from qadence.ml_tools.data import OptimizeResult +from qadence.ml_tools.optimize_step import optimize_step, update_ng_parameters +from qadence.ml_tools.stages import TrainingStage + +from .train_utils.base_trainer import BaseTrainer + +logger = getLogger("ml_tools") + + +class Trainer(BaseTrainer): + """Trainer class to manage and execute training, validation, and testing loops for a model (eg. + + QNN). + + This class handles the overall training process, including: + - Managing epochs and steps + - Handling data loading and batching + - Computing and updating gradients + - Logging and monitoring training metrics + + Attributes: + current_epoch (int): The current epoch number. + global_step (int): The global step across all epochs. + log_device (str): Device for logging, default is "cpu". + device (torch_device): Device used for computation. + dtype (torch_dtype | None): Data type used for computation. + data_dtype (torch_dtype | None): Data type for data. + Depends on the model's data type. + + Inherited Attributes: + use_grad (bool): Indicates if gradients are used for optimization. Default is True. + + model (nn.Module): The neural network model. + optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training. + config (TrainConfig): The configuration settings for training. + train_dataloader (DataLoader | None): DataLoader for training data. + val_dataloader (DataLoader | None): DataLoader for validation data. + test_dataloader (DataLoader | None): DataLoader for testing data. + + optimize_step (Callable): Function for performing an optimization step. + loss_fn (Callable): loss function to use. + + num_training_batches (int): Number of training batches. + num_validation_batches (int): Number of validation batches. + num_test_batches (int): Number of test batches. + + state (str): Current state in the training process + + Default training routine + ``` + for epoch in max_iter + 1: + # Training + for batch in train_batches: + train model + # Validation + if val_every % epoch == 0: + for batch in val_batches: + train model + ``` + + Notes: + - In case of InfiniteTensorDataset, number of batches = 1. + - In case of TensorDataset, number of batches are default. + - Training is run for max_iter + 1 epochs. Epoch 0 logs untrained model. + - Please look at the CallbackManager initialize_callbacks method to review the default + logging behavior. + + Examples: + + ```python + import torch + from torch.optim import SGD + from qadence import ( + feature_map, + hamiltonian_factory, + hea, + QNN, + QuantumCircuit, + TrainConfig, + Z, + ) + from qadence.ml_tools.trainer import Trainer + from qadence.ml_tools.optimize_step import optimize_step + from qadence.ml_tools import TrainConfig + from qadence.ml_tools.data import to_dataloader + + # Initialize the model + n_qubits = 2 + fm = feature_map(n_qubits) + ansatz = hea(n_qubits=n_qubits, depth=2) + observable = hamiltonian_factory(n_qubits, detuning=Z) + circuit = QuantumCircuit(n_qubits, fm, ansatz) + model = QNN(circuit, observable, backend="pyqtorch", diff_mode="ad") + + # Set up the optimizer + optimizer = SGD(model.parameters(), lr=0.001) + + # Use TrainConfig for configuring the training process + config = TrainConfig( + max_iter=100, + print_every=10, + write_every=10, + checkpoint_every=10, + val_every=10 + ) + + # Create the Trainer instance with TrainConfig + trainer = Trainer( + model=model, + optimizer=optimizer, + config=config, + loss_fn="mse", + optimize_step=optimize_step + ) + + batch_size = 25 + x = torch.linspace(0, 1, 32).reshape(-1, 1) + y = torch.sin(x) + train_loader = to_dataloader(x, y, batch_size=batch_size, infinite=True) + val_loader = to_dataloader(x, y, batch_size=batch_size, infinite=False) + + # Train the model + model, optimizer = trainer.fit(train_loader, val_loader) + ``` + + This also supports both gradient based and gradient free optimization. + The default support is for gradient based optimization. + + Notes: + + - **set_use_grad()** (*class level*):This method is used to set the global `use_grad` flag, + controlling whether the trainer uses gradient-based optimization. + ```python + # gradient based + Trainer.set_use_grad(True) + + # gradient free + Trainer.set_use_grad(False) + ``` + - **Context Managers** (*instance level*): `enable_grad_opt()` and `disable_grad_opt()` are + context managers that temporarily switch the optimization mode for specific code blocks. + This is useful when you want to mix gradient-based and gradient-free optimization + in the same training process. + ```python + # gradient based + with trainer.enable_grad_opt(optimizer): + trainer.fit() + + # gradient free + with trainer.disable_grad_opt(ng_optimizer): + trainer.fit() + ``` + + Examples + + *Gradient based optimization example Usage*: + ```python + from torch import optim + optimizer = optim.SGD(model.parameters(), lr=0.01) + + Trainer.set_use_grad(True) + trainer = Trainer( + model=model, + optimizer=optimizer, + config=config, + loss_fn="mse" + ) + trainer.fit(train_loader, val_loader) + ``` + or + ```python + trainer = Trainer( + model=model, + config=config, + loss_fn="mse" + ) + with trainer.enable_grad_opt(optimizer): + trainer.fit(train_loader, val_loader) + ``` + + *Gradient free optimization example Usage*: + ```python + import nevergrad as ng + from qadence.ml_tools.parameters import num_parameters + ng_optimizer = ng.optimizers.NGOpt( + budget=config.max_iter, parametrization= num_parameters(model) + ) + + Trainer.set_use_grad(False) + trainer = Trainer( + model=model, + optimizer=ng_optimizer, + config=config, + loss_fn="mse" + ) + trainer.fit(train_loader, val_loader) + ``` + or + ```python + import nevergrad as ng + from qadence.ml_tools.parameters import num_parameters + ng_optimizer = ng.optimizers.NGOpt( + budget=config.max_iter, parametrization= num_parameters(model) + ) + + trainer = Trainer( + model=model, + config=config, + loss_fn="mse" + ) + with trainer.disable_grad_opt(ng_optimizer): + trainer.fit(train_loader, val_loader) + ``` + """ + + def __init__( + self, + model: nn.Module, + optimizer: optim.Optimizer | NGOptimizer | None, + config: TrainConfig, + loss_fn: str | Callable = "mse", + train_dataloader: DataLoader | None = None, + val_dataloader: DataLoader | None = None, + test_dataloader: DataLoader | None = None, + optimize_step: Callable = optimize_step, + device: torch_device | None = None, + dtype: torch_dtype | None = None, + max_batches: int | None = None, + ): + """ + Initializes the Trainer class. + + Args: + model (nn.Module): The PyTorch model to train. + optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training. + config (TrainConfig): Training configuration object. + loss_fn (str | Callable ): Loss function used for training. + If not specified, default mse loss will be used. + train_dataloader (DataLoader | None): DataLoader for training data. + val_dataloader (DataLoader | None): DataLoader for validation data. + test_dataloader (DataLoader | None): DataLoader for test data. + optimize_step (Callable): Function to execute an optimization step. + device (torch_device): Device to use for computation. + dtype (torch_dtype): Data type for computation. + max_batches (int | None): Maximum number of batches to process per epoch. + This is only valid in case of finite TensorDataset dataloaders. + if max_batches is not None, the maximum number of batches used will + be min(max_batches, len(dataloader.dataset)) + In case of InfiniteTensorDataset only 1 batch per epoch is used. + """ + super().__init__( + model=model, + optimizer=optimizer, + config=config, + loss_fn=loss_fn, + optimize_step=optimize_step, + train_dataloader=train_dataloader, + val_dataloader=val_dataloader, + test_dataloader=test_dataloader, + max_batches=max_batches, + ) + self.current_epoch: int = 0 + self.global_step: int = 0 + self.log_device: str = "cpu" if device is None else device + self.device: torch_device | None = device + self.dtype: torch_dtype | None = dtype + self.data_dtype: torch_dtype | None = None + if self.dtype: + self.data_dtype = float64 if (self.dtype == complex128) else float32 + + def fit( + self, train_dataloader: DataLoader | None = None, val_dataloader: DataLoader | None = None + ) -> tuple[nn.Module, optim.Optimizer]: + """ + Fits the model using the specified training configuration. + + The dataloaders can be provided to train on new datasets, or the default dataloaders + provided in the trainer will be used. + + Args: + train_dataloader (DataLoader | None): DataLoader for training data. + val_dataloader (DataLoader | None): DataLoader for validation data. + + Returns: + tuple[nn.Module, optim.Optimizer]: The trained model and optimizer. + """ + if train_dataloader is not None: + self.train_dataloader = train_dataloader + if val_dataloader is not None: + self.val_dataloader = val_dataloader + + self._fit_setup() + self._train() + self._fit_end() + self.training_stage = TrainingStage("idle") + return self.model, self.optimizer + + def _fit_setup(self) -> None: + """ + Sets up the training environment, initializes configurations,. + + and moves the model to the specified device and data type. + The callback_manager.start_training takes care of loading checkpoint, + and setting up the writer. + """ + self.config_manager.initialize_config() + self.callback_manager.start_training(trainer=self) + + # Move model to device + if isinstance(self.model, nn.DataParallel): + self.model = self.model.module.to(device=self.device, dtype=self.dtype) + else: + self.model = self.model.to(device=self.device, dtype=self.dtype) + + # Progress bar for training visualization + self.progress: Progress = Progress( + TextColumn("[progress.description]{task.description}"), + BarColumn(), + TaskProgressColumn(), + TimeRemainingColumn(elapsed_when_finished=True), + ) + + # Quick Fix for build_optimize_step + # Please review run_train_batch for more details + self.model_old = copy.deepcopy(self.model) + self.optimizer_old = copy.deepcopy(self.optimizer) + + # Run validation at the start if specified in the configuration + self.perform_val = self.config.val_every > 0 + if self.perform_val: + self.run_validation(self.val_dataloader) + + def _fit_end(self) -> None: + """Finalizes the training and closes the writer.""" + self.callback_manager.end_training(trainer=self) + + @BaseTrainer.callback("train") + def _train(self) -> list[list[tuple[torch.Tensor, dict[str, Any]]]]: + """ + Runs the main training loop, iterating over epochs. + + Returns: + list[list[tuple[torch.Tensor, dict[str, Any]]]]: Training loss + metrics for all epochs. + list -> list -> tuples + Epochs -> Training Batches -> (loss, metrics) + """ + self.on_train_start() + train_losses = [] + val_losses = [] + + with self.progress: + train_task = self.progress.add_task( + "Training", total=self.config_manager.config.max_iter + ) + if self.perform_val: + val_task = self.progress.add_task( + "Validation", + total=(self.config_manager.config.max_iter + 1) / self.config.val_every, + ) + for epoch in range( + self.global_step, self.global_step + self.config_manager.config.max_iter + 1 + ): + try: + self.current_epoch = epoch + self.on_train_epoch_start() + train_epoch_loss_metrics = self.run_training(self.train_dataloader) + train_losses.append(train_epoch_loss_metrics) + self.on_train_epoch_end(train_epoch_loss_metrics) + + # Run validation periodically if specified + if self.perform_val and self.current_epoch % self.config.val_every == 0: + self.on_val_epoch_start() + val_epoch_loss_metrics = self.run_validation(self.val_dataloader) + val_losses.append(val_epoch_loss_metrics) + self.on_val_epoch_end(val_epoch_loss_metrics) + self.progress.update(val_task, advance=1) + + self.progress.update(train_task, advance=1) + except KeyboardInterrupt: + logger.info("Terminating training gracefully after the current iteration.") + break + + self.on_train_end(train_losses, val_losses) + return train_losses + + @BaseTrainer.callback("train_epoch") + def run_training(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]: + """ + Runs the training for a single epoch, iterating over multiple batches. + + Args: + dataloader (DataLoader): DataLoader for training data. + + Returns: + list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch. + list -> tuples + Training Batches -> (loss, metrics) + """ + self.model.train() + train_epoch_loss_metrics = [] + # Deep copy model and optimizer to maintain checkpoints + # We do this because optimize step provides loss, metrics + # before step of optimization + # To align them with model/optimizer correctly, we checkpoint + # the older copy of the model. + # TODO: review optimize_step to provide iteration aligned model and loss. + self.model_old = copy.deepcopy(self.model) + self.optimizer_old = copy.deepcopy(self.optimizer) + + for batch in self.batch_iter(dataloader, self.num_training_batches): + self.on_train_batch_start(batch) + train_batch_loss_metrics = self.run_train_batch(batch) + train_epoch_loss_metrics.append(train_batch_loss_metrics) + self.on_train_batch_end(train_batch_loss_metrics) + + return train_epoch_loss_metrics + + @BaseTrainer.callback("train_batch") + def run_train_batch( + self, batch: tuple[torch.Tensor, ...] + ) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Runs a single training batch, performing optimization. + + We use the step function to optimize the model based on use_grad. + use_grad = True entails gradient based optimization, for which we use + optimize_step function. + use_grad = False entails gradient free optimization, for which we use + update_ng_parameters function. + + Args: + batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader. + + Returns: + tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch. + tuple of (loss, metrics) + """ + + if self.use_grad: + # Perform gradient-based optimization + loss_metrics = self.optimize_step( + model=self.model, + optimizer=self.optimizer, + loss_fn=self.loss_fn, + xs=batch, + device=self.device, + dtype=self.data_dtype, + ) + else: + # Perform optimization using Nevergrad + loss, metrics, ng_params = update_ng_parameters( + model=self.model, + optimizer=self.optimizer, + loss_fn=self.loss_fn, + data=batch, + ng_params=self.ng_params, # type: ignore[arg-type] + ) + self.ng_params = ng_params + loss_metrics = loss, metrics + + return self.modify_batch_end_loss_metrics(loss_metrics) + + @BaseTrainer.callback("val_epoch") + def run_validation(self, dataloader: DataLoader) -> list[tuple[torch.Tensor, dict[str, Any]]]: + """ + Runs the validation loop for a single epoch, iterating over multiple batches. + + Args: + dataloader (DataLoader): DataLoader for validation data. + + Returns: + list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch. + list -> tuples + Validation Batches -> (loss, metrics) + """ + self.model.eval() + val_epoch_loss_metrics = [] + + for batch in self.batch_iter(dataloader, self.num_validation_batches): + self.on_val_batch_start(batch) + val_batch_loss_metrics = self.run_val_batch(batch) + val_epoch_loss_metrics.append(val_batch_loss_metrics) + self.on_val_batch_end(val_batch_loss_metrics) + + return val_epoch_loss_metrics + + @BaseTrainer.callback("val_batch") + def run_val_batch(self, batch: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Runs a single validation batch. + + Args: + batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader. + + Returns: + tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch. + """ + with torch.no_grad(): + loss_metrics = self.loss_fn(self.model, batch) + return self.modify_batch_end_loss_metrics(loss_metrics) + + def test(self, test_dataloader: DataLoader = None) -> list[tuple[torch.Tensor, dict[str, Any]]]: + """ + Runs the testing loop if a test DataLoader is provided. + + if the test_dataloader is not provided, default test_dataloader defined + in the Trainer class is used. + + Args: + test_dataloader (DataLoader): DataLoader for test data. + + Returns: + list[tuple[torch.Tensor, dict[str, Any]]]: Loss and metrics for each batch. + list -> tuples + Test Batches -> (loss, metrics) + """ + if test_dataloader is not None: + self.test_dataloader = test_dataloader + + self.model.eval() + test_loss_metrics = [] + + for batch in self.batch_iter(test_dataloader, self.num_training_batches): + self.on_test_batch_start(batch) + loss_metrics = self.run_test_batch(batch) + test_loss_metrics.append(loss_metrics) + self.on_test_batch_end(loss_metrics) + + return test_loss_metrics + + @BaseTrainer.callback("test_batch") + def run_test_batch( + self, batch: tuple[torch.Tensor, ...] + ) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Runs a single test batch. + + Args: + batch (tuple[torch.Tensor, ...]): Batch of data from the DataLoader. + + Returns: + tuple[torch.Tensor, dict[str, Any]]: Loss and metrics for the batch. + """ + with torch.no_grad(): + loss_metrics = self.loss_fn(self.model, batch) + return self.modify_batch_end_loss_metrics(loss_metrics) + + def batch_iter( + self, + dataloader: DataLoader, + num_batches: int, + ) -> Iterable[tuple[torch.Tensor, ...] | None]: + """ + Yields batches from the provided dataloader. + + Args: + dataloader ([DataLoader]): The dataloader to iterate over. + num_batches (int): The maximum number of batches to yield. + + Yields: + Iterable[tuple[torch.Tensor, ...] | None]: A batch from the dataloader moved to the + specified device and dtype. + """ + if dataloader is None: + for _ in range(num_batches): + yield None + else: + for batch in islice(dataloader, num_batches): + # batch is moved to device inside optimize step + # batch = data_to_device(batch, device=self.device, dtype=self.data_dtype) + yield batch + + def modify_batch_end_loss_metrics( + self, loss_metrics: tuple[torch.Tensor, dict[str, Any]] + ) -> tuple[torch.Tensor, dict[str, Any]]: + """ + Modifies the loss and metrics at the end of batch for proper logging. + + All metrics are prefixed with the proper state of the training process + - "train_" or "val_" or "test_" + A "{state}_loss" is added to metrics. + + Args: + loss_metrics (tuple[torch.Tensor, dict[str, Any]]): Original loss and metrics. + + Returns: + tuple[None | torch.Tensor, dict[str, Any]]: Modified loss and metrics. + """ + for phase in ["train", "val", "test"]: + if phase in self.training_stage: + loss, metrics = loss_metrics + updated_metrics = {f"{phase}_{key}": value for key, value in metrics.items()} + updated_metrics[f"{phase}_loss"] = loss + return loss, updated_metrics + return loss_metrics + + def build_optimize_result( + self, + result: None + | tuple[torch.Tensor, dict[Any, Any]] + | list[tuple[torch.Tensor, dict[Any, Any]]] + | list[list[tuple[torch.Tensor, dict[Any, Any]]]], + ) -> None: + """ + Builds and stores the optimization result by calculating the average loss and metrics. + + Result (or loss_metrics) can have multiple formats: + - `None` Indicates no loss or metrics data is provided. + - `tuple[torch.Tensor, dict[str, Any]]` A single tuple containing the loss tensor + and metrics dictionary - at the end of batch. + - `list[tuple[torch.Tensor, dict[str, Any]]]` A list of tuples for + multiple batches. + - `list[list[tuple[torch.Tensor, dict[str, Any]]]]` A list of lists of tuples, + where each inner list represents metrics across multiple batches within an epoch. + + Args: + result: (None | + tuple[torch.Tensor, dict[Any, Any]] | + list[tuple[torch.Tensor, dict[Any, Any]]] | + list[list[tuple[torch.Tensor, dict[Any, Any]]]]) + The loss and metrics data, which can have multiple formats + + Returns: + None: This method does not return anything. It sets `self.opt_result` with + the computed average loss and metrics. + """ + loss_metrics = result + if loss_metrics is None: + loss = None + metrics: dict[Any, Any] = {} + elif isinstance(loss_metrics, tuple): + # Single tuple case + loss, metrics = loss_metrics + else: + last_epoch: list[tuple[torch.Tensor, dict[Any, Any]]] = [] + if isinstance(loss_metrics, list): + # Check if it's a list of tuples + if all(isinstance(item, tuple) for item in loss_metrics): + last_epoch = cast(list[tuple[torch.Tensor, dict[Any, Any]]], loss_metrics) + # Check if it's a list of lists of tuples + elif all(isinstance(item, list) for item in loss_metrics): + last_epoch = cast( + list[tuple[torch.Tensor, dict[Any, Any]]], + loss_metrics[-1] if loss_metrics else [], + ) + else: + raise ValueError( + "Invalid format for result: Expected None, tuple, list of tuples," + " or list of lists of tuples." + ) + + if not last_epoch: + loss, metrics = None, {} + else: + # Compute the average loss over the batches + loss_tensor = torch.stack([loss_batch for loss_batch, _ in last_epoch]) + avg_loss = loss_tensor.mean() + + # Collect and average metrics for all batches + metric_keys = last_epoch[0][1].keys() + metrics_stacked: dict = {key: [] for key in metric_keys} + + for _, metrics_batch in last_epoch: + for key in metric_keys: + value = metrics_batch[key] + metrics_stacked[key].append(value) + + avg_metrics = {key: torch.stack(metrics_stacked[key]).mean() for key in metric_keys} + + loss, metrics = avg_loss, avg_metrics + + # Store the optimization result + self.opt_result = OptimizeResult( + self.current_epoch, self.model_old, self.optimizer_old, loss, metrics + ) diff --git a/qadence/model.py b/qadence/model.py index fdc311f88..80d3fc404 100644 --- a/qadence/model.py +++ b/qadence/model.py @@ -514,7 +514,7 @@ def load( if isinstance(file_path, str): file_path = Path(file_path) if os.path.isdir(file_path): - from qadence.ml_tools.saveload import get_latest_checkpoint_name + from qadence.ml_tools.callbacks.saveload import get_latest_checkpoint_name file_path = file_path / get_latest_checkpoint_name(file_path, "model") diff --git a/tests/ml_tools/test_callbacks.py b/tests/ml_tools/test_callbacks.py new file mode 100644 index 000000000..4cb6911d4 --- /dev/null +++ b/tests/ml_tools/test_callbacks.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +from pathlib import Path +from unittest.mock import Mock + +import pytest +import torch +from torch.utils.data import DataLoader + +from qadence.ml_tools import TrainConfig, Trainer +from qadence.ml_tools.callbacks import ( + LoadCheckpoint, + LogHyperparameters, + LogModelTracker, + PlotMetrics, + PrintMetrics, + SaveBestCheckpoint, + SaveCheckpoint, + WriteMetrics, +) +from qadence.ml_tools.callbacks.saveload import write_checkpoint +from qadence.ml_tools.data import OptimizeResult, to_dataloader +from qadence.ml_tools.stages import TrainingStage + + +def dataloader(batch_size: int = 25) -> DataLoader: + x = torch.linspace(0, 1, batch_size).reshape(-1, 1) + y = torch.cos(x) + return to_dataloader(x, y, batch_size=batch_size, infinite=True) + + +@pytest.fixture +def trainer(Basic: torch.nn.Module, tmp_path: Path) -> Trainer: + """Set up a real Trainer with a Basic and optimizer.""" + data = dataloader() + model = Basic + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + config = TrainConfig( + log_folder=tmp_path, + max_iter=1, + checkpoint_best_only=True, + validation_criterion=lambda loss, best, ep: loss < (best - ep), + val_epsilon=1e-5, + ) + trainer = Trainer( + model=model, optimizer=optimizer, config=config, loss_fn="mse", train_dataloader=data + ) + trainer.opt_result = OptimizeResult( + iteration=1, + model=model, + optimizer=optimizer, + loss=torch.tensor(0.5), + metrics={"accuracy": torch.tensor(0.8)}, + ) + trainer.training_stage = TrainingStage("train_start") + return trainer + + +def test_save_checkpoint(trainer: Trainer) -> None: + writer = trainer.callback_manager.writer = Mock() + stage = trainer.training_stage + callback = SaveCheckpoint(stage, called_every=1) + callback(stage, trainer, trainer.config, writer) + + checkpoint_file = ( + trainer.config.log_folder / f"model_{type(trainer.model).__name__}_ckpt_001_device_cpu.pt" + ) + assert checkpoint_file.exists() + + +def test_save_best_checkpoint(trainer: Trainer) -> None: + writer = trainer.callback_manager.writer = Mock() + stage = trainer.training_stage + callback = SaveBestCheckpoint(on=stage, called_every=1) + callback(stage, trainer, trainer.config, writer) + + best_checkpoint_file = ( + trainer.config.log_folder / f"model_{type(trainer.model).__name__}_ckpt_best_device_cpu.pt" + ) + assert best_checkpoint_file.exists() + assert callback.best_loss == trainer.opt_result.loss + + +def test_print_metrics(trainer: Trainer) -> None: + writer = trainer.callback_manager.writer = Mock() + stage = trainer.training_stage + callback = PrintMetrics(on=stage, called_every=1) + callback(stage, trainer, trainer.config, writer) + writer.print_metrics.assert_called_once_with(trainer.opt_result) + + +def test_write_metrics(trainer: Trainer) -> None: + writer = trainer.callback_manager.writer = Mock() + stage = trainer.training_stage + callback = WriteMetrics(on=stage, called_every=1) + callback(stage, trainer, trainer.config, writer) + writer.write.assert_called_once_with(trainer.opt_result) + + +def test_plot_metrics(trainer: Trainer) -> None: + trainer.config.plotting_functions = (lambda model, iteration: ("plot_name", None),) + writer = trainer.callback_manager.writer = Mock() + stage = trainer.training_stage + callback = PlotMetrics(stage, called_every=1) + callback(stage, trainer, trainer.config, writer) + + writer.plot.assert_called_once_with( + trainer.model, + trainer.opt_result.iteration, + trainer.config.plotting_functions, + ) + + +def test_log_hyperparameters(trainer: Trainer) -> None: + writer = trainer.callback_manager.writer = Mock() + stage = trainer.training_stage + trainer.config.hyperparams = {"learning_rate": 0.01, "epochs": 10} + callback = LogHyperparameters(stage, called_every=1) + callback(stage, trainer, trainer.config, writer) + writer.log_hyperparams.assert_called_once_with(trainer.config.hyperparams) + + +def test_load_checkpoint(trainer: Trainer) -> None: + # Prepare a checkpoint + write_checkpoint(trainer.config.log_folder, trainer.model, trainer.optimizer, iteration=1) + writer = trainer.callback_manager.writer = Mock() + stage = trainer.training_stage + callback = LoadCheckpoint(stage, called_every=1) + model, optimizer, iteration = callback(stage, trainer, trainer.config, writer) + + assert model is not None + assert optimizer is not None + assert iteration == 1 + + +def test_log_model_tracker(trainer: Trainer) -> None: + writer = trainer.callback_manager.writer = Mock() + callback = LogModelTracker(on=trainer.training_stage, called_every=1) + callback(trainer.training_stage, trainer, trainer.config, writer) + writer.log_model.assert_called_once_with( + trainer.model, + trainer.train_dataloader, + trainer.val_dataloader, + trainer.test_dataloader, + ) diff --git a/tests/ml_tools/test_checkpointing.py b/tests/ml_tools/test_checkpointing.py index 8877e52c5..6bb69c974 100644 --- a/tests/ml_tools/test_checkpointing.py +++ b/tests/ml_tools/test_checkpointing.py @@ -12,8 +12,8 @@ from qadence import QNN, QuantumModel from qadence.ml_tools import ( TrainConfig, + Trainer, load_checkpoint, - train_with_grad, ) from qadence.ml_tools.data import to_dataloader from qadence.ml_tools.parameters import get_parameters, set_parameters @@ -63,13 +63,15 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d loss = criterion(out, y) return loss, {} - config = TrainConfig(folder=tmp_path, max_iter=1, checkpoint_every=1, write_every=1) - model, _ = train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + config = TrainConfig(root_folder=tmp_path, max_iter=1, checkpoint_every=1, write_every=1) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + model, _ = trainer.fit() ps0 = get_parameters(model) set_parameters(model, torch.ones(len(get_parameters(model)))) # write_checkpoint(tmp_path, model, optimizer, 1) # check that saved model has ones - model, _, _ = load_checkpoint(tmp_path, model, optimizer) + model, _, _ = load_checkpoint(trainer.config.log_folder, model, optimizer) ps1 = get_parameters(model) assert torch.allclose(ps0, ps1) @@ -87,14 +89,16 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict loss = criterion(out, torch.rand(1)) return loss, {} - config = TrainConfig(folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) - model, optimizer = train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + config = TrainConfig(root_folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + model, optimizer = trainer.fit() ps0 = get_parameters(model) ev0 = model.expectation({}) # Modify model's parameters set_parameters(model, torch.ones(len(ps0))) - model, optimizer, _ = load_checkpoint(tmp_path, model, optimizer) + model, optimizer, _ = load_checkpoint(trainer.config.log_folder, model, optimizer) ps1 = get_parameters(model) ev1 = model.expectation({}) @@ -103,7 +107,7 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict assert torch.allclose(ev0, ev1) loaded_model, optimizer, _ = load_checkpoint( - tmp_path, + trainer.config.log_folder, BasicQuantumModel, optimizer, "model_QuantumModel_ckpt_009_device_cpu.pt", @@ -112,6 +116,72 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict assert not torch.all(torch.isnan(loaded_model.expectation({}))) +def test_create_subfolders_perrun(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None: + data = dataloader() + model = BasicQuantumModel + cnt = count() + criterion = torch.nn.MSELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict]: + next(cnt) + out = model.expectation({}).squeeze(dim=0) + loss = criterion(out, torch.rand(1)) + return loss, {} + + config = TrainConfig(root_folder=tmp_path, max_iter=10, create_subfolder_per_run=False) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() + with trainer.enable_grad_opt(): + trainer.fit() + + assert os.path.isdir(tmp_path) + assert len(os.listdir(tmp_path)) == 1 + + config = TrainConfig(root_folder=tmp_path, max_iter=10, create_subfolder_per_run=True) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() + with trainer.enable_grad_opt(): + trainer.fit() + + assert os.path.isdir(tmp_path) + assert len(os.listdir(tmp_path)) == 3 + + +def test_log_folder_logging(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None: + data = dataloader() + model = BasicQuantumModel + cnt = count() + criterion = torch.nn.MSELoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.1) + + def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict]: + next(cnt) + out = model.expectation({}).squeeze(dim=0) + loss = criterion(out, torch.rand(1)) + return loss, {} + + config = TrainConfig(log_folder=tmp_path, max_iter=10, checkpoint_every=1) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() + + assert os.path.isdir(tmp_path) + + ckpts = [ + trainer.config.log_folder / Path(f"model_QuantumModel_ckpt_00{i}_device_cpu.pt") + for i in range(1, 9) + ] + assert all(os.path.isfile(ckpt) for ckpt in ckpts) + for ckpt in ckpts: + loaded_model, optimizer, _ = load_checkpoint( + tmp_path, BasicQuantumModel, optimizer, ckpt, "" + ) + assert torch.allclose(loaded_model.expectation({}), model.expectation({})) + + def test_check_ckpts_exist(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None: data = dataloader() model = BasicQuantumModel @@ -125,9 +195,14 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict loss = criterion(out, torch.rand(1)) return loss, {} - config = TrainConfig(folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) - ckpts = [tmp_path / Path(f"model_QuantumModel_ckpt_00{i}_device_cpu.pt") for i in range(1, 9)] + config = TrainConfig(root_folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() + ckpts = [ + trainer.config.log_folder / Path(f"model_QuantumModel_ckpt_00{i}_device_cpu.pt") + for i in range(1, 9) + ] assert all(os.path.isfile(ckpt) for ckpt in ckpts) for ckpt in ckpts: loaded_model, optimizer, _ = load_checkpoint( @@ -150,15 +225,17 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict loss = criterion(out, torch.rand(1)) return loss, {} - config = TrainConfig(folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + config = TrainConfig(root_folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() ps0 = get_parameters(model) ev0 = model.expectation(inputs) # Modify model's parameters set_parameters(model, torch.ones(len(ps0))) - model, optimizer, _ = load_checkpoint(tmp_path, model, optimizer) + model, optimizer, _ = load_checkpoint(trainer.config.log_folder, model, optimizer) ps1 = get_parameters(model) ev1 = model.expectation(inputs) @@ -190,9 +267,13 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict loss = criterion(out, torch.rand(1)) return loss, {} - config = TrainConfig(folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) - ckpts = [tmp_path / Path(f"model_QNN_ckpt_00{i}_device_cpu.pt") for i in range(1, 9)] + config = TrainConfig(root_folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() + ckpts = [ + trainer.config.log_folder / Path(f"model_QNN_ckpt_00{i}_device_cpu.pt") for i in range(1, 9) + ] assert all(os.path.isfile(ckpt) for ckpt in ckpts) for ckpt in ckpts: loaded_model, optimizer, _ = load_checkpoint(tmp_path, BasicQNN, optimizer, ckpt, "") diff --git a/tests/ml_tools/test_logging.py b/tests/ml_tools/test_logging.py index 640dacd8c..7f24b6361 100644 --- a/tests/ml_tools/test_logging.py +++ b/tests/ml_tools/test_logging.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import os import shutil from itertools import count @@ -18,7 +19,8 @@ from torch.optim import Optimizer from torch.utils.data import DataLoader -from qadence.ml_tools import TrainConfig, train_with_grad +from qadence.ml_tools import TrainConfig, Trainer +from qadence.ml_tools.callbacks.writer_registry import BaseWriter from qadence.ml_tools.data import to_dataloader from qadence.ml_tools.models import QNN from qadence.ml_tools.utils import rand_featureparameters @@ -47,8 +49,8 @@ def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict return loss_fn, optimizer -def load_mlflow_model(train_config: TrainConfig) -> None: - run_id = train_config.mlflow_config.run.info.run_id +def load_mlflow_model(writer: BaseWriter) -> None: + run_id = writer.run.info.run_id mlflow.pytorch.load_model(model_uri=f"runs:/{run_id}/model") @@ -59,8 +61,8 @@ def find_mlflow_artifacts_path(run: Run) -> Path: return Path(os.path.abspath(os.path.join(parsed_uri.netloc, parsed_uri.path))) -def clean_mlflow_experiment(train_config: TrainConfig) -> None: - experiment_id = train_config.mlflow_config.run.info.experiment_id +def clean_mlflow_experiment(writer: BaseWriter) -> None: + experiment_id = writer.run.info.experiment_id client = MlflowClient() runs = client.search_runs(experiment_id) @@ -80,6 +82,15 @@ def clean_artifacts(run: Run) -> None: shutil.rmtree(os.path.join(mlruns_base_dir, experiment_id)) +def setup_logger() -> logging.Logger: + logger = logging.getLogger("ml_tools") + # an additional streamhandler is needed in ml_tools as + # caplog does not record richhandler logs. + stream_handler = logging.StreamHandler() + logger.addHandler(stream_handler) + return logger + + def test_hyperparams_logging_mlflow(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None: model = BasicQuantumModel @@ -88,7 +99,7 @@ def test_hyperparams_logging_mlflow(BasicQuantumModel: QuantumModel, tmp_path: P hyperparams = {"max_iter": int(10), "lr": 0.1} config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=hyperparams["max_iter"], # type: ignore checkpoint_every=1, write_every=1, @@ -96,18 +107,20 @@ def test_hyperparams_logging_mlflow(BasicQuantumModel: QuantumModel, tmp_path: P tracking_tool=ExperimentTrackingTool.MLFLOW, ) - train_with_grad(model, None, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, None) + with trainer.enable_grad_opt(): + trainer.fit() - mlflow_config = config.mlflow_config - experiment_id = mlflow_config.run.info.experiment_id - run_id = mlflow_config.run.info.run_id + writer = trainer.callback_manager.writer + experiment_id = writer.run.info.experiment_id + run_id = writer.run.info.run_id experiment_dir = Path(f"mlruns/{experiment_id}") hyperparams_files = [experiment_dir / run_id / "params" / key for key in hyperparams.keys()] assert all([os.path.isfile(hf) for hf in hyperparams_files]) - clean_mlflow_experiment(config) + clean_mlflow_experiment(trainer.callback_manager.writer) def test_hyperparams_logging_tensorboard(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None: @@ -118,7 +131,7 @@ def test_hyperparams_logging_tensorboard(BasicQuantumModel: QuantumModel, tmp_pa hyperparams = {"max_iter": int(10), "lr": 0.1} config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=hyperparams["max_iter"], # type: ignore checkpoint_every=1, write_every=1, @@ -126,7 +139,9 @@ def test_hyperparams_logging_tensorboard(BasicQuantumModel: QuantumModel, tmp_pa tracking_tool=ExperimentTrackingTool.TENSORBOARD, ) - train_with_grad(model, None, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, None) + with trainer.enable_grad_opt(): + trainer.fit() def test_model_logging_mlflow_basicQM(BasicQuantumModel: QuantumModel, tmp_path: Path) -> None: @@ -134,7 +149,7 @@ def test_model_logging_mlflow_basicQM(BasicQuantumModel: QuantumModel, tmp_path: loss_fn, optimizer = setup_model(model) config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=10, # type: ignore checkpoint_every=1, write_every=1, @@ -142,11 +157,13 @@ def test_model_logging_mlflow_basicQM(BasicQuantumModel: QuantumModel, tmp_path: tracking_tool=ExperimentTrackingTool.MLFLOW, ) - train_with_grad(model, None, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, None) + with trainer.enable_grad_opt(): + trainer.fit() - load_mlflow_model(config) + load_mlflow_model(trainer.callback_manager.writer) - clean_mlflow_experiment(config) + clean_mlflow_experiment(trainer.callback_manager.writer) def test_model_logging_mlflow_basicQNN(BasicQNN: QNN, tmp_path: Path) -> None: @@ -156,7 +173,7 @@ def test_model_logging_mlflow_basicQNN(BasicQNN: QNN, tmp_path: Path) -> None: loss_fn, optimizer = setup_model(model) config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=10, # type: ignore checkpoint_every=1, write_every=1, @@ -164,11 +181,13 @@ def test_model_logging_mlflow_basicQNN(BasicQNN: QNN, tmp_path: Path) -> None: tracking_tool=ExperimentTrackingTool.MLFLOW, ) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() - load_mlflow_model(config) + load_mlflow_model(trainer.callback_manager.writer) - clean_mlflow_experiment(config) + clean_mlflow_experiment(trainer.callback_manager.writer) def test_model_logging_mlflow_basicAdjQNN(BasicAdjointQNN: QNN, tmp_path: Path) -> None: @@ -178,7 +197,7 @@ def test_model_logging_mlflow_basicAdjQNN(BasicAdjointQNN: QNN, tmp_path: Path) loss_fn, optimizer = setup_model(model) config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=10, # type: ignore checkpoint_every=1, write_every=1, @@ -186,22 +205,25 @@ def test_model_logging_mlflow_basicAdjQNN(BasicAdjointQNN: QNN, tmp_path: Path) tracking_tool=ExperimentTrackingTool.MLFLOW, ) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() - load_mlflow_model(config) + load_mlflow_model(trainer.callback_manager.writer) - clean_mlflow_experiment(config) + clean_mlflow_experiment(trainer.callback_manager.writer) def test_model_logging_tensorboard( - BasicQuantumModel: QuantumModel, tmp_path: Path, caplog: pytest.LogCaptureFixture + BasicQuantumModel: QuantumModel, tmp_path: Path, capsys: pytest.LogCaptureFixture ) -> None: + setup_logger() model = BasicQuantumModel loss_fn, optimizer = setup_model(model) config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=10, # type: ignore checkpoint_every=1, write_every=1, @@ -209,9 +231,12 @@ def test_model_logging_tensorboard( tracking_tool=ExperimentTrackingTool.TENSORBOARD, ) - train_with_grad(model, None, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, None) + with trainer.enable_grad_opt(): + trainer.fit() - assert "Model logging is not supported by tensorboard. No model will be logged." in caplog.text + captured = capsys.readouterr() + assert "Model logging is not supported by tensorboard. No model will be logged." in captured.err def test_plotting_mlflow(BasicQNN: QNN, tmp_path: Path) -> None: @@ -241,7 +266,7 @@ def plot_error(model: QuantumModel, iteration: int) -> tuple[str, Figure]: max_iter = 10 plot_every = 2 config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=max_iter, checkpoint_every=1, write_every=1, @@ -250,16 +275,18 @@ def plot_error(model: QuantumModel, iteration: int) -> tuple[str, Figure]: plotting_functions=(plot_model, plot_error), ) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() all_plot_names = [f"model_prediction_epoch_{i}.png" for i in range(0, max_iter, plot_every)] all_plot_names.extend([f"error_epoch_{i}.png" for i in range(0, max_iter, plot_every)]) - artifact_path = find_mlflow_artifacts_path(config.mlflow_config.run) + artifact_path = find_mlflow_artifacts_path(trainer.callback_manager.writer.run) assert all([os.path.isfile(artifact_path / pn) for pn in all_plot_names]) - clean_mlflow_experiment(config) + clean_mlflow_experiment(trainer.callback_manager.writer) def test_plotting_tensorboard(BasicQNN: QNN, tmp_path: Path) -> None: @@ -287,7 +314,7 @@ def plot_error(model: QuantumModel, iteration: int) -> tuple[str, Figure]: return descr, fig config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=10, checkpoint_every=1, write_every=1, @@ -295,4 +322,6 @@ def plot_error(model: QuantumModel, iteration: int) -> tuple[str, Figure]: plotting_functions=(plot_model, plot_error), ) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() diff --git a/tests/ml_tools/test_train.py b/tests/ml_tools/test_train.py index 6d6a21eef..212c5c205 100644 --- a/tests/ml_tools/test_train.py +++ b/tests/ml_tools/test_train.py @@ -10,7 +10,7 @@ import torch from torch.utils.data import DataLoader -from qadence.ml_tools import QNN, DictDataLoader, TrainConfig, to_dataloader, train_with_grad +from qadence.ml_tools import QNN, DictDataLoader, TrainConfig, Trainer, to_dataloader torch.manual_seed(42) np.random.seed(42) @@ -22,14 +22,12 @@ def dataloader(batch_size: int = 25) -> DataLoader: return to_dataloader(x, y, batch_size=batch_size, infinite=True) -def dictdataloader(batch_size: int = 25, val: bool = False) -> DictDataLoader: +def train_val_dataloaders(batch_size: int = 25) -> tuple: x = torch.rand(batch_size, 1) y = torch.sin(x) - dls = { - "train" if val else "y1": to_dataloader(x, y, batch_size=batch_size, infinite=True), - "val" if val else "y2": to_dataloader(x, y, batch_size=batch_size, infinite=True), - } - return DictDataLoader(dls) + train_dataloader = to_dataloader(x, y, batch_size=batch_size, infinite=True) + val_dataloader = to_dataloader(x, y, batch_size=batch_size, infinite=True) + return train_dataloader, val_dataloader def validation_criterion( @@ -42,7 +40,7 @@ def get_train_config_validation( tmp_path: Path, n_epochs: int, checkpoint_every: int, val_every: int ) -> TrainConfig: config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=n_epochs, print_every=10, checkpoint_every=checkpoint_every, @@ -73,14 +71,18 @@ def test_train_dataloader_default(tmp_path: Path, Basic: torch.nn.Module) -> Non def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: next(cnt) - x, y = data[0], data[1] + x, y = data out = model(x) loss = criterion(out, y) return loss, {} n_epochs = 100 - config = TrainConfig(folder=tmp_path, max_iter=n_epochs, checkpoint_every=100, write_every=100) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + config = TrainConfig( + root_folder=tmp_path, max_iter=n_epochs, checkpoint_every=100, write_every=100 + ) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() assert next(cnt) == (n_epochs + 1) x = torch.rand(5, 1) @@ -103,13 +105,15 @@ def loss_fn(model: torch.nn.Module, xs: Any = None) -> tuple[torch.Tensor, dict] n_epochs = 50 config = TrainConfig( - folder=tmp_path, + root_folder=tmp_path, max_iter=n_epochs, print_every=5, checkpoint_every=100, write_every=100, ) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() assert next(cnt) == (n_epochs + 1) out = model() @@ -117,9 +121,9 @@ def loss_fn(model: torch.nn.Module, xs: Any = None) -> tuple[torch.Tensor, dict] @pytest.mark.flaky(max_runs=10) -def test_train_dictdataloader(tmp_path: Path, Basic: torch.nn.Module) -> None: +def test_train_val(tmp_path: Path, Basic: torch.nn.Module) -> None: batch_size = 25 - data = dictdataloader(batch_size=batch_size) + train_data, val_data = train_val_dataloaders(batch_size=batch_size) model = Basic cnt = count() @@ -128,17 +132,23 @@ def test_train_dictdataloader(tmp_path: Path, Basic: torch.nn.Module) -> None: def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: next(cnt) - x1, y1 = data["y1"][0], data["y1"][1] - x2, y2 = data["y2"][0], data["y2"][1] + x1, y1 = data l1 = criterion(model(x1), y1) - l2 = criterion(model(x2), y2) - return l1 + l2, {} + return l1, {} n_epochs = 100 config = TrainConfig( - folder=tmp_path, max_iter=n_epochs, print_every=10, checkpoint_every=100, write_every=100 + root_folder=tmp_path, + max_iter=n_epochs, + print_every=10, + checkpoint_every=100, + write_every=100, + ) + trainer = Trainer( + model, optimizer, config, loss_fn, train_dataloader=train_data, val_dataloader=val_data ) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + with trainer.enable_grad_opt(): + trainer.fit() assert next(cnt) == (n_epochs + 1) x = torch.rand(5, 1) @@ -152,6 +162,9 @@ def test_train_tensor_tuple(Basic: torch.nn.Module, BasicQNN: QNN) -> None: batch_size = 25 x = torch.linspace(0, 1, batch_size).reshape(-1, 1) y = torch.sin(x) + model = model.to( + torch.float32 + ) # BasicQNN might have float64, and Adam behaves weirdly with mixed precision cnt = count() criterion = torch.nn.MSELoss() @@ -159,7 +172,7 @@ def test_train_tensor_tuple(Basic: torch.nn.Module, BasicQNN: QNN) -> None: def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: next(cnt) - x, y = data[0], data[1] + x, y = data out = model(x) loss = criterion(out, y) return loss, {} @@ -172,7 +185,9 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d batch_size=batch_size, ) data = to_dataloader(x, y, batch_size=batch_size, infinite=True) - model, _ = train_with_grad(model, data, optimizer, config, loss_fn=loss_fn, dtype=dtype) + trainer = Trainer(model, optimizer, config, loss_fn, data, dtype=dtype) + with trainer.enable_grad_opt(): + model, _ = trainer.fit() assert next(cnt) == (n_epochs + 1) x = torch.rand(5, 1, dtype=torch.float32) @@ -212,7 +227,7 @@ def test_train_dataloader_val_check_and_non_dict_dataloader( def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: next(cnt) - x1, y1 = data["y1"][0], data["y1"][1] + x1, y1 = data loss = criterion(model(x1), y1) return loss, {} @@ -222,16 +237,18 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d config = get_train_config_validation(tmp_path, n_epochs, checkpoint_every, val_every) with pytest.raises(ValueError) as exc_info: - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + trainer = Trainer(model, optimizer, config, loss_fn, data) + with trainer.enable_grad_opt(): + trainer.fit() assert ( - "If `config.val_every` is provided as an integer, dataloader must" - "be an instance of `DictDataLoader`" in exc_info.exconly() + "If `config.val_every` is provided as an integer > 0, validation_dataloader" + "must be an instance of `DataLoader`." in exc_info.exconly() ) def test_train_dataloader_val_check_incorrect_keys(tmp_path: Path, Basic: torch.nn.Module) -> None: batch_size = 25 - data = dictdataloader(batch_size=batch_size, val=False) # Passing val=False to raise an error. + train_data, _ = train_val_dataloaders(batch_size=batch_size) model = Basic cnt = count() @@ -240,7 +257,7 @@ def test_train_dataloader_val_check_incorrect_keys(tmp_path: Path, Basic: torch. def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: next(cnt) - x1, y1 = data[0], data[1] + x1, y1 = data loss = criterion(model(x1), y1) return loss, {} @@ -250,17 +267,20 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d config = get_train_config_validation(tmp_path, n_epochs, checkpoint_every, val_every) with pytest.raises(ValueError) as exc_info: - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) + trainer = Trainer( + model, optimizer, config, loss_fn, train_dataloader=train_data, val_dataloader=None + ) + with trainer.enable_grad_opt(): + trainer.fit() assert ( - "If `config.val_every` is provided as an integer, the dictdataloader" - "must have `train` and `val` keys to access the respective dataloaders." - in exc_info.exconly() + "If `config.val_every` is provided as an integer > 0, validation_dataloader" + "must be an instance of `DataLoader`." in exc_info.exconly() ) -def test_train_dictdataloader_checkpoint_best_only(tmp_path: Path, Basic: torch.nn.Module) -> None: +def test_train_val_checkpoint_best_only(tmp_path: Path, Basic: torch.nn.Module) -> None: batch_size = 25 - data = dictdataloader(batch_size=batch_size, val=True) + train_data, val_data = train_val_dataloaders(batch_size=batch_size) model = Basic cnt = count() @@ -269,7 +289,7 @@ def test_train_dictdataloader_checkpoint_best_only(tmp_path: Path, Basic: torch. def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]: next(cnt) - x1, y1 = data[0], data[1] + x1, y1 = data loss = criterion(model(x1), y1) return loss, {} @@ -278,10 +298,14 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d val_every = 10 config = get_train_config_validation(tmp_path, n_epochs, checkpoint_every, val_every) - train_with_grad(model, data, optimizer, config, loss_fn=loss_fn) - assert next(cnt) == 2 + n_epochs + n_epochs // val_every + trainer = Trainer( + model, optimizer, config, loss_fn, train_dataloader=train_data, val_dataloader=val_data + ) + with trainer.enable_grad_opt(): + trainer.fit() + assert next(cnt) == 2 + n_epochs + (n_epochs // val_every) + 1 # 1 for intial round 0 run - files = [f for f in os.listdir(tmp_path) if f.endswith(".pt") and "model" in f] + files = [f for f in os.listdir(trainer.config.log_folder) if f.endswith(".pt") and "model" in f] # Ideally it can be ensured if the (only) saved checkpoint is indeed the best, # but that is time-consuming since training must be run twice for comparison. # The below check may be plausible enough. diff --git a/tests/ml_tools/test_train_no_grad.py b/tests/ml_tools/test_train_no_grad.py index 229cdcb0a..036133dc0 100644 --- a/tests/ml_tools/test_train_no_grad.py +++ b/tests/ml_tools/test_train_no_grad.py @@ -10,7 +10,7 @@ import torch from torch.utils.data import DataLoader -from qadence.ml_tools import TrainConfig, num_parameters, to_dataloader, train_gradient_free +from qadence.ml_tools import TrainConfig, Trainer, num_parameters, to_dataloader # ensure reproducibility SEED = 42 @@ -43,12 +43,16 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d return loss, {} n_epochs = 500 - config = TrainConfig(folder=tmp_path, max_iter=n_epochs, checkpoint_every=100, write_every=100) + config = TrainConfig( + root_folder=tmp_path, max_iter=n_epochs, checkpoint_every=100, write_every=100 + ) optimizer = ng.optimizers.NGOpt(budget=config.max_iter, parametrization=num_parameters(model)) - train_gradient_free(model, data, optimizer, config, loss_fn=loss_fn) - assert next(cnt) == n_epochs + trainer = Trainer(model, None, config=config, loss_fn=loss_fn, train_dataloader=data) + with trainer.disable_grad_opt(optimizer): + trainer.fit() + assert next(cnt) == n_epochs + 1 x = torch.rand(5, 1) assert torch.allclose(torch.cos(x), model(x), rtol=1e-1, atol=1e-1)