[Refactor][Feature] Train functions refactoring #593

Merged Nov 26, 2024 (72 commits). Changes shown from all commits.

Commits
2713958  add callbacks manager class (mlahariya, Oct 21, 2024)
f412323  adding writer registry (mlahariya, Oct 23, 2024)
2eeb8c7  loss functions (mlahariya, Oct 24, 2024)
808adec  Adding callback manager (mlahariya, Oct 30, 2024)
fd90dd9  Add default loss functions (mlahariya, Oct 30, 2024)
f4833ee  cleanup (mlahariya, Oct 30, 2024)
1b2e302  Update mkdocs with Trainer Tutorials (mlahariya, Nov 7, 2024)
91d5289  Update api ml_tools.md (mlahariya, Nov 7, 2024)
b1e820e  Update custom-models.md (mlahariya, Nov 7, 2024)
fbc6d1b  Update tutorial analog-qubo (mlahariya, Nov 7, 2024)
a5f6fb8  Update tutorial dqc_1d (mlahariya, Nov 7, 2024)
49e974d  Update tutorial index (mlahariya, Nov 7, 2024)
57b9a72  Update tutorial qaoa (mlahariya, Nov 7, 2024)
744094b  Update tutorial qcl (mlahariya, Nov 7, 2024)
fe280ca  Update example quick start with Trainer (mlahariya, Nov 7, 2024)
9672a17  Add tutorial Trainer (mlahariya, Nov 7, 2024)
5c00c1e  Add tutorial data and configurations (mlahariya, Nov 7, 2024)
098d922  Add tutorial callbacks (mlahariya, Nov 7, 2024)
03d582e  Remove tutorial ml_tools (mlahariya, Nov 7, 2024)
72422f5  Update tests ml_tools test checkpointing with Trainer (mlahariya, Nov 7, 2024)
d309495  Update tests ml_tools test training with Trainer (mlahariya, Nov 7, 2024)
31e9f4d  Update tests ml_tools test logging with Trainer (mlahariya, Nov 7, 2024)
4dc08a1  Update tests ml_tools test grad free training with Trainer (mlahariya, Nov 7, 2024)
51d61f8  Update qadence ml_tools init (mlahariya, Nov 7, 2024)
a387779  Add qadence callbacks init (mlahariya, Nov 7, 2024)
8429467  Add ml_tools callbacks classes (mlahariya, Nov 7, 2024)
d9b8499  Add ml_tools callback manager (mlahariya, Nov 7, 2024)
d87029e  Move ml_tools saveload to callbacks (mlahariya, Nov 7, 2024)
b908117  Add ml_tools writers and writer registry (mlahariya, Nov 7, 2024)
bc02be9  Remove old callback commits (mlahariya, Nov 7, 2024)
9040976  Add train_utils init (mlahariya, Nov 7, 2024)
9768142  Add train_utils BaseTrainer (mlahariya, Nov 7, 2024)
32d1690  Add train_utils ConfigManager (mlahariya, Nov 7, 2024)
ea13696  Add ml_tools loss init (mlahariya, Nov 7, 2024)
cdc292a  Add ml_tools loss functions with get_loss_fun (mlahariya, Nov 7, 2024)
c9c6712  Update optimize step to also offer update_ng_parameters (mlahariya, Nov 7, 2024)
d6150d1  Update saveload location (mlahariya, Nov 7, 2024)
037ec3c  Update ml_tools TrainConfig (mlahariya, Nov 7, 2024)
30d1679  Remove train_grad (mlahariya, Nov 7, 2024)
51ff70f  Remove train_no_grad (mlahariya, Nov 7, 2024)
de28615  Add ml_tools Trainer (mlahariya, Nov 7, 2024)
e4c486e  Merge remote-tracking branch 'origin/main' into ML_570 (Nov 7, 2024)
0f2bb7d  Update tests and linting (mlahariya, Nov 7, 2024)
b8b97c3  Update docs to have successful build (mlahariya, Nov 7, 2024)
390a6a0  Final checks and cleanup (mlahariya, Nov 7, 2024)
257160d  Update suggested changes - Docs (mlahariya, Nov 14, 2024)
9842e02  Update Suggested Changes - ml_tools (mlahariya, Nov 14, 2024)
9d8c4d6  Minor doc and docstring updates (mlahariya, Nov 14, 2024)
cbfe46d  Update default config callbacks (mlahariya, Nov 14, 2024)
32690ed  Update docs dataconfig with root folder (mlahariya, Nov 14, 2024)
a6d2d28  Update docs trainer with root folder (mlahariya, Nov 14, 2024)
cd54dbc  Update log_config with ml_tools RichHandler logger (mlahariya, Nov 14, 2024)
9fe21e9  Update TrainConfig root_folder (mlahariya, Nov 14, 2024)
55c771b  Update callbacks for root_folder (mlahariya, Nov 14, 2024)
5f99753  Update config manager with root folder and add warnings (mlahariya, Nov 14, 2024)
91c438d  Update saveload (mlahariya, Nov 14, 2024)
a9adbf2  Update BaseTrainer logger (mlahariya, Nov 14, 2024)
dee69a4  Update trainer logger (mlahariya, Nov 14, 2024)
258cedd  Update writer registry with logfolder (mlahariya, Nov 14, 2024)
02b4a01  Update capsys stderr readout (mlahariya, Nov 14, 2024)
300f9f6  Minor cleanup (mlahariya, Nov 14, 2024)
dfa0651  Update docs with root/log folder (mlahariya, Nov 19, 2024)
ee64f62  Update trainer with setter for use_grad (mlahariya, Nov 19, 2024)
80afb1b  Update Config Manager (mlahariya, Nov 19, 2024)
625f17b  Update checkpointing tests (mlahariya, Nov 19, 2024)
27f221e  Add callback tests (mlahariya, Nov 20, 2024)
97e8744  Update validation criterion use (mlahariya, Nov 20, 2024)
56b9706  Update writer init in callback manager (mlahariya, Nov 20, 2024)
6fd235a  Minor changes (mlahariya, Nov 20, 2024)
4506961  Update docs with infinite dataloader (mlahariya, Nov 20, 2024)
7254b26  Update changes in num_batches and inf dataloader (mlahariya, Nov 20, 2024)
5fa6e4c  Merge branch 'main' into ML_570 (mlahariya, Nov 22, 2024)
15 changes: 10 additions & 5 deletions docs/api/ml_tools.md

@@ -1,17 +1,22 @@
## ML Tools

-This module implements gradient-free and gradient-based training loops for torch Modules and QuantumModel. It also implements the QNN class.
+This module implements a `Trainer` class for torch `Modules` and `QuantumModel`. It also implements the `QNN` class and callbacks that can be used with the trainer module.

+### ::: qadence.ml_tools.trainer
+
### ::: qadence.ml_tools.config

### ::: qadence.ml_tools.parameters

### ::: qadence.ml_tools.optimize_step

-### ::: qadence.ml_tools.train_grad
-
-### ::: qadence.ml_tools.train_no_grad
-
### ::: qadence.ml_tools.data

### ::: qadence.ml_tools.models
+
+### ::: qadence.ml_tools.callbacks.callback
+
+### ::: qadence.ml_tools.train_utils.base_trainer
+
+### ::: qadence.ml_tools.callbacks.writer_registry
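The new API page no longer carries a usage example, so for orientation here is a minimal sketch of the gradient-based path assembled from the tutorial diffs below. The stand-in `torch.nn.Linear` model and toy data are illustrative assumptions; only the `Trainer`, `TrainConfig`, `set_use_grad`, and `fit` usage is taken from this PR.

```python
# Minimal sketch of the new Trainer API, assembled from this PR's diffs.
# The stand-in torch.nn.Linear model and random data are illustrative only.
import torch
from qadence.ml_tools import Trainer, TrainConfig

Trainer.set_use_grad(True)  # select the gradient-based training path

model = torch.nn.Linear(2, 1)  # stand-in for a QuantumModel / QNN
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)

# Loss functions return a (loss, metrics) tuple, as in the tutorials below.
def loss_fn(model: torch.nn.Module, _unused: object) -> tuple[torch.Tensor, dict]:
    x = torch.rand(16, 2)
    return criterion(model(x), torch.zeros(16, 1)), {}

config = TrainConfig(max_iter=100)
trainer = Trainer(model, optimizer, config, loss_fn)
model, optimizer = trainer.fit()
```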
8 changes: 5 additions & 3 deletions docs/tutorials/advanced_tutorials/custom-models.md

@@ -118,7 +118,8 @@ This model can then be trained with the standard Qadence helper functions.

```python exec="on" source="material-block" result="json" session="custom-model"
from qadence import run
-from qadence.ml_tools import train_with_grad, TrainConfig
+from qadence.ml_tools import Trainer, TrainConfig
+Trainer.set_use_grad(True)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)

@@ -128,9 +129,10 @@ def loss_fn(model: LearnHadamard, _unused) -> tuple[torch.Tensor, dict]:
    return loss, {}

config = TrainConfig(max_iter=2500)
-model, optimizer = train_with_grad(
-    model, None, optimizer, config, loss_fn=loss_fn
+trainer = Trainer(
+    model, optimizer, config, loss_fn
)
+model, optimizer = trainer.fit()

wf_target = run(target_circuit)
assert torch.allclose(wf_target, model.wavefunction(), atol=1e-2)
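Note from this diff that the choice between gradient-based and gradient-free training is now a class-level switch (commit ee64f62 adds the `use_grad` setter) rather than a choice between two train functions. A short sketch of the two modes as they appear in these diffs:

```python
from qadence.ml_tools import Trainer

# Gradient-based training, paired with a torch optimizer (this tutorial):
Trainer.set_use_grad(True)

# Gradient-free training, paired with a nevergrad optimizer
# (see the analog-qubo diff below):
Trainer.set_use_grad(False)
```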
11 changes: 7 additions & 4 deletions docs/tutorials/digital_analog_qc/analog-qubo.md

@@ -56,7 +56,7 @@ ensure the reproducibility of this tutorial.
import torch
from qadence import QuantumModel, QuantumCircuit, Register
from qadence import RydbergDevice, AnalogRX, AnalogRZ, chain
-from qadence.ml_tools import train_gradient_free, TrainConfig, num_parameters
+from qadence.ml_tools import Trainer, TrainConfig, num_parameters
import nevergrad as ng
import matplotlib.pyplot as plt

@@ -80,12 +80,12 @@ Q = np.array(
    ]
)

-def loss(model: QuantumModel, *args) -> tuple[float, dict]:
+def loss(model: QuantumModel, *args) -> tuple[torch.Tensor, dict]:
    to_arr_fn = lambda bitstring: np.array(list(bitstring), dtype=int)
    cost_fn = lambda arr: arr.T @ Q @ arr
    samples = model.sample({}, n_shots=1000)[0] # extract samples
    cost_fn = sum(samples[key] * cost_fn(to_arr_fn(key)) for key in samples)
-    return cost_fn / sum(samples.values()), {} # We return an optional metrics dict
+    return torch.tensor(cost_fn / sum(samples.values())), {} # We return an optional metrics dict
```

The QAOA algorithm needs a variational quantum circuit with optimizable parameters.

@@ -132,11 +132,14 @@ ML facilities to run gradient-free optimizations using the
[`nevergrad`](https://facebookresearch.github.io/nevergrad/) library.

```python exec="on" source="material-block" session="qubo"
+Trainer.set_use_grad(False)
+
config = TrainConfig(max_iter=100)
optimizer = ng.optimizers.NGOpt(
    budget=config.max_iter, parametrization=num_parameters(model)
)
-train_gradient_free(model, None, optimizer, config, loss)
+trainer = Trainer(model, optimizer, config, loss)
+trainer.fit()

optimal_counts = model.sample({}, n_shots=1000)[0]
print(f"optimal_count = {optimal_counts}") # markdown-exec: hide
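To try the gradient-free path outside this tutorial, a self-contained sketch of the same pattern follows. The stand-in linear model and toy loss are illustrative assumptions; only the `Trainer`/`num_parameters`/`NGOpt` wiring comes from the diff above.

```python
# Gradient-free training sketch mirroring the analog-qubo diff above
# (stand-in torch.nn.Linear model; illustrative only).
import torch
import nevergrad as ng
from qadence.ml_tools import Trainer, TrainConfig, num_parameters

Trainer.set_use_grad(False)  # select the gradient-free training path

model = torch.nn.Linear(3, 1)  # stand-in for the tutorial's QuantumModel

def loss(model: torch.nn.Module, *args) -> tuple[torch.Tensor, dict]:
    x = torch.rand(8, 3)
    return torch.mean(model(x) ** 2), {}  # (loss, optional metrics dict)

config = TrainConfig(max_iter=100)
optimizer = ng.optimizers.NGOpt(
    budget=config.max_iter, parametrization=num_parameters(model)
)
trainer = Trainer(model, optimizer, config, loss)
trainer.fit()
```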
2 changes: 1 addition & 1 deletion docs/tutorials/qml/dqc_1d.md

@@ -112,7 +112,7 @@ print(html_string(circuit)) # markdown-exec: hide

## Training the model

-Now that the model is defined we can proceed with the training. The `QNN` class can be used like any other `torch.nn.Module`. Here we write a simple training loop, but you can also look at the [ml tools tutorial](ml_tools.md) to use the convenience training functions that Qadence provides.
+Now that the model is defined we can proceed with the training. The `QNN` class can be used like any other `torch.nn.Module`. Here we write a simple training loop, but you can also look at the [ml tools tutorial](ml_tools/trainer.md) to use the convenience training functions that Qadence provides.

To train the model, we will select a random set of collocation points uniformly distributed within $-1.0 < x < 1.0$ and compute the loss function for those points.
2 changes: 1 addition & 1 deletion docs/tutorials/qml/index.md

@@ -6,7 +6,7 @@ differentiation via integration with [PyTorch](https://pytorch.org/) deep learning
Furthermore, Qadence offers a wide range of utilities for helping building and researching quantum machine learning algorithms, including:

* [a set of constructors](../../content/qml_constructors.md) for circuits commonly used in quantum machine learning such as feature maps and ansatze
-* [a set of tools](ml_tools.md) for training and optimizing quantum neural networks and loading classical data into a QML algorithm
+* [a set of tools](ml_tools/trainer.md) for training and optimizing quantum neural networks and loading classical data into a QML algorithm

## Some simple examples