[Feature] Add additional callbacks #633

Open · wants to merge 1 commit into base: main
57 changes: 57 additions & 0 deletions qadence/ml_tools/callbacks/callback.py
@@ -449,3 +449,60 @@
        writer.log_model(
            model, trainer.train_dataloader, trainer.val_dataloader, trainer.test_dataloader
        )

class LRSchedulerExponentialDecay(Callback):
    """
    Applies exponential decay to the learning rate during training.

    This callback adjusts the learning rate at regular intervals by multiplying
    it with a decay factor. The learning rate is updated as:
        lr = lr * gamma

    Example Usage in `TrainConfig`:
    To use `LRSchedulerExponentialDecay`, include it in the `callbacks` list
    when setting up your `TrainConfig`:
    ```python exec="on" source="material-block" result="json"
    from qadence.ml_tools import TrainConfig
    from qadence.ml_tools.callbacks import LRSchedulerExponentialDecay

    # Create an instance of the LRSchedulerExponentialDecay callback
    lr_exponential_decay = LRSchedulerExponentialDecay(
        on="train_epoch_end",
        called_every=100,
        gamma=0.9,
    )

    config = TrainConfig(
        max_iter=10000,
        # Print metrics every 1000 training epochs
        print_every=1000,
        # Add the custom callback that runs every 100 train_epoch_end events
        callbacks=[lr_exponential_decay],
    )
    ```
    """

    def __init__(self, on: str, called_every: int, gamma: float = 0.9):
        """Initializes the LRSchedulerExponentialDecay callback.

        Args:
            on (str): The event to trigger the callback.
            called_every (int): Frequency of callback calls in terms of iterations.
            gamma (float, optional): The decay factor applied to the learning rate.
                A value < 1 reduces the learning rate over time.
                Default is 0.9.
        """
        super().__init__(on=on, called_every=called_every)

        if gamma > 1:
            raise ValueError(f"Gamma must be less than or equal to 1, but got {gamma}.")
        self.gamma = gamma


    def run_callback(self, trainer: Any, config: TrainConfig, writer: BaseWriter) -> None:
        """
        Runs the callback to apply exponential decay to the learning rate.

        Args:
            trainer (Any): The training object.
            config (TrainConfig): The configuration object.
            writer (BaseWriter): The writer object for logging.
        """
        for param_group in trainer.optimizer.param_groups:
            param_group["lr"] *= self.gamma
