Skip to content

Commit

Permalink
Updates minor doc and doc strings updates
Browse files Browse the repository at this point in the history
  • Loading branch information
mlahariya committed Nov 14, 2024
1 parent 9842e02 commit 9d8c4d6
Show file tree
Hide file tree
Showing 11 changed files with 167 additions and 316 deletions.
1 change: 1 addition & 0 deletions docs/tutorials/advanced_tutorials/custom-models.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ This model can then be trained with the standard Qadence helper functions.
```python exec="on" source="material-block" result="json" session="custom-model"
from qadence import run
from qadence.ml_tools import Trainer, TrainConfig
Trainer.set_use_grad(True)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)
Expand Down
42 changes: 24 additions & 18 deletions docs/tutorials/qml/ml_tools/callbacks.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Qadence ml_tools offers several built-in callbacks for common tasks like saving

Prints metrics at specified intervals.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import PrintMetrics

Expand All @@ -38,7 +38,8 @@ config = TrainConfig(

Writes metrics to a specified logging destination.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import WriteMetrics

write_metrics_callback = WriteMetrics(on="train_epoch_end", called_every=50)
Expand All @@ -53,7 +54,8 @@ config = TrainConfig(

Plots metrics based on user-defined plotting functions.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import PlotMetrics

plot_metrics_callback = PlotMetrics(on="train_epoch_end", called_every=100)
Expand All @@ -68,7 +70,8 @@ config = TrainConfig(

Logs hyperparameters to keep track of training settings.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import LogHyperparameters

log_hyper_callback = LogHyperparameters(on="train_start", called_every=1)
Expand All @@ -83,7 +86,8 @@ config = TrainConfig(

Saves model checkpoints at specified intervals.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import SaveCheckpoint

save_checkpoint_callback = SaveCheckpoint(on="train_epoch_end", called_every=100)
Expand All @@ -98,7 +102,8 @@ config = TrainConfig(

Saves the best model checkpoint based on a validation criterion.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import SaveBestCheckpoint

save_best_checkpoint_callback = SaveBestCheckpoint(on="val_epoch_end", called_every=10)
Expand All @@ -113,7 +118,8 @@ config = TrainConfig(

Loads a saved model checkpoint at the start of training.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import LoadCheckpoint

load_checkpoint_callback = LoadCheckpoint(on="train_start")
Expand All @@ -128,7 +134,8 @@ config = TrainConfig(

Logs the model structure and parameters.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import LogModelTracker

log_model_callback = LogModelTracker(on="train_end")
Expand All @@ -152,7 +159,7 @@ There are two main ways to define a callback:

#### Example 1: Providing a Callback Function Directly

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools.callbacks import Callback

# Define a custom callback function
Expand All @@ -161,39 +168,38 @@ def custom_callback_function(trainer, config, writer):

# Create the callback instance
custom_callback = Callback(
on="on_train_end",
called_every=5,
on="train_end",
callback=custom_callback_function
)
```

#### Example 2: Subclassing the Callback

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools.callbacks import Callback

class CustomCallback(Callback):
def run_callback(self, trainer, config, writer):
print("Custom behavior in run_callback method.")

# Create the subclassed callback instance
custom_callback = CustomCallback(on="on_train_end", called_every=10)
custom_callback = CustomCallback(on="train_batch_end", called_every=10)
```


## 3. Adding Callbacks to `TrainConfig`

To use callbacks in `TrainConfig`, add them to the `callbacks` list when configuring the training process.

```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools import TrainConfig
from qadence.ml_tools.callbacks import SaveCheckpoint, PrintMetrics

config = TrainConfig(
max_iter=10000,
callbacks=[
SaveCheckpoint(on="on_val_epoch_end", called_every=50),
PrintMetrics(on="on_train_epoch_end", called_every=100),
SaveCheckpoint(on="val_epoch_end", called_every=50),
PrintMetrics(on="train_epoch_end", called_every=100),
]
)
```
Expand All @@ -217,9 +223,9 @@ These defaults handle common needs, but you can also add custom callbacks to any
To create a custom `Trainer` that includes a `PrintMetrics` callback executed specifically at the end of each epoch, follow the steps below.


```python
```python exec="on" source="material-block" html="1"
from qadence.ml_tools.trainer import Trainer
from qadence.ml_tools.callback import PrintMetrics
from qadence.ml_tools.callbacks import PrintMetrics

class CustomTrainer(Trainer):
def __init__(self, *args, **kwargs):
Expand Down
4 changes: 2 additions & 2 deletions docs/tutorials/qml/ml_tools/data_and_config.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ n_epochs = 100
print_parameters = lambda opt_res: print(opt_res.model.parameters())
condition_print = lambda opt_res: opt_res.loss < 1.0e-03
modify_extra_opt_res = {"n_epochs": n_epochs}
custom_callback = Callback( on="on_train_end", callback = print_parameters, callback_condition=condition_print, modify_optimize_result=modify_extra_opt_res, called_every=10,)
custom_callback = Callback(on="train_end", callback = print_parameters, callback_condition=condition_print, modify_optimize_result=modify_extra_opt_res, called_every=10,)

config = TrainConfig(
folder="some_path/",
Expand Down Expand Up @@ -170,7 +170,7 @@ def callback_fn(trainer, config, writer):
if trainer.opt_res.loss < 0.001:
print("Custom Callback: Loss threshold reached!")

custom_callback = Callback(on = "on_train_epoch_end", called_every = 10, callback_function = callback_fn )
custom_callback = Callback(on = "train_epoch_end", called_every = 10, callback_function = callback_fn )

config = TrainConfig(callbacks=[custom_callback])
```
Expand Down
16 changes: 8 additions & 8 deletions qadence/ml_tools/callbacks/callback.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Optional, Union
from typing import Any, Callable

from qadence.ml_tools.callbacks.saveload import load_checkpoint, write_checkpoint
from qadence.ml_tools.callbacks.writer_registry import BaseWriter
Expand All @@ -24,9 +24,9 @@ class Callback:
"val_batch_start", "val_batch_end", "test_batch_start",
"test_batch_end"]
called_every (int): Frequency of callback calls in terms of iterations.
callback (Optional[CallbackFunction]): The function to call if the condition is met.
callback_condition (Optional[CallbackConditionFunction]): Condition to check before calling.
modify_optimize_result (Optional[Union[CallbackFunction, dict[str, Any]]]):
callback (CallbackFunction | None): The function to call if the condition is met.
callback_condition (CallbackConditionFunction | None): Condition to check before calling.
modify_optimize_result (CallbackFunction | dict[str, Any] | None):
Function to modify `OptimizeResult`.
A callback can be defined in two ways:
Expand Down Expand Up @@ -81,14 +81,14 @@ def __init__(
self,
on: str | TrainingStage = "idle",
called_every: int = 1,
callback: Union[CallbackFunction, None] = None,
callback_condition: Union[CallbackConditionFunction, None] = None,
modify_optimize_result: Optional[Union[CallbackFunction, dict[str, Any]]] = None,
callback: CallbackFunction | None = None,
callback_condition: CallbackConditionFunction | None = None,
modify_optimize_result: CallbackFunction | dict[str, Any] | None = None,
):
if not isinstance(called_every, int):
raise ValueError("called_every must be a positive integer or 0")

self.callback: Union[CallbackFunction, None] = callback
self.callback: CallbackFunction | None = callback
self.on: str | TrainingStage = on
self.called_every: int = called_every
self.callback_condition = callback_condition or (lambda _: True)
Expand Down
4 changes: 2 additions & 2 deletions qadence/ml_tools/callbacks/writer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class BaseWriter(ABC):
run: Run # [attr-defined]

@abstractmethod
def open(self, config: TrainConfig, iteration: int = None) -> Any:
def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
"""
Opens the writer and prepares it for logging.
Expand Down Expand Up @@ -262,7 +262,7 @@ def __init__(self) -> None:
self.run: Run
self.mlflow: ModuleType

def open(self, config: TrainConfig, iteration: int = None) -> ModuleType | None:
def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType | None:
"""
Opens the MLflow writer and initializes an MLflow run.
Expand Down
18 changes: 9 additions & 9 deletions qadence/ml_tools/loss/loss.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from __future__ import annotations

from typing import Callable, Dict, Union
from typing import Callable

import torch
import torch.nn as nn


def mse_loss(
model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, Dict[str, float]]:
) -> tuple[torch.Tensor, dict[str, float]]:
"""Computes the Mean Squared Error (MSE) loss between model predictions and targets.
Args:
Expand All @@ -18,9 +18,9 @@ def mse_loss(
- targets (torch.Tensor): The ground truth labels.
Returns:
Tuple[torch.Tensor, Dict[str, float]]:
Tuple[torch.Tensor, dict[str, float]]:
- loss (torch.Tensor): The computed MSE loss value.
- metrics (Dict[str, float]): A dictionary with the MSE loss value.
- metrics (dict[str, float]): A dictionary with the MSE loss value.
"""
criterion = nn.MSELoss()
inputs, targets = batch
Expand All @@ -33,7 +33,7 @@ def mse_loss(

def cross_entropy_loss(
model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, Dict[str, float]]:
) -> tuple[torch.Tensor, dict[str, float]]:
"""Computes the Cross Entropy loss between model predictions and targets.
Args:
Expand All @@ -43,9 +43,9 @@ def cross_entropy_loss(
- targets (torch.Tensor): The ground truth labels.
Returns:
Tuple[torch.Tensor, Dict[str, float]]:
Tuple[torch.Tensor, dict[str, float]]:
- loss (torch.Tensor): The computed Cross Entropy loss value.
- metrics (Dict[str, float]): A dictionary with the Cross Entropy loss value.
- metrics (dict[str, float]): A dictionary with the Cross Entropy loss value.
"""
criterion = nn.CrossEntropyLoss()
inputs, targets = batch
Expand All @@ -56,12 +56,12 @@ def cross_entropy_loss(
return loss, metrics


def get_loss_fn(loss_fn: Union[None, Callable, str]) -> Callable:
def get_loss_fn(loss_fn: str | Callable | None) -> Callable:
"""
Returns the appropriate loss function based on the input argument.
Args:
loss_fn (Union[None, Callable, str]): The loss function to use.
loss_fn (str | Callable | None): The loss function to use.
- If `loss_fn` is a callable, it will be returned directly.
- If `loss_fn` is a string, it should be one of:
- "mse": Returns the `mse_loss` function.
Expand Down
Loading

0 comments on commit 9d8c4d6

Please sign in to comment.