[Feature] Add mlflow experiment tracker (#450)
Co-authored-by: debrevitatevitae <[email protected]>
Co-authored-by: seitzdom <[email protected]>
Co-authored-by: Roland-djee <[email protected]>
4 people authored Jul 19, 2024
1 parent 08652a4 commit 5fa4641
Showing 9 changed files with 778 additions and 35 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Expand Up @@ -75,3 +75,7 @@ events.out.tfevents.*
*.dvi

*.gv

# mlflow
mlruns/
mlartifacts/
126 changes: 126 additions & 0 deletions docs/tutorials/qml/ml_tools.md
Expand Up @@ -306,3 +306,129 @@ def train(

return model, optimizer
```

## Experiment tracking with mlflow

Qadence can track runs and log hyperparameters, models and plots with [tensorboard](https://pytorch.org/tutorials/recipes/recipes/tensorboard_with_pytorch.html) and [mlflow](https://mlflow.org/). In the following, we demonstrate the integration with mlflow.

### mlflow configuration
We control the tracking configuration by setting environment variables. First, let's look at the tracking URI. For the purpose of this demo we will work with a local database, similarly to what is described [here](https://mlflow.org/docs/latest/tracking/tutorials/local-database.html):
```bash
export MLFLOW_TRACKING_URI=sqlite:///mlruns.db
```

Qadence can also read the following two environment variables to define the mlflow experiment name and run name:
```bash
export MLFLOW_EXPERIMENT=test_experiment
export MLFLOW_RUN_NAME=run_0
```

If no tracking URI is provided, mlflow stores run information and artifacts in the local `./mlruns` directory. If no names are defined, the experiment and the run are each named with a random UUID.
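
The fallback behaviour described above can be sketched with the standard library alone. The snippet below is an illustration, not Qadence's implementation: the helper name `resolve_mlflow_settings` and the explicit `env` argument are assumptions, added so the example runs without touching the real environment.

```python
import os
from typing import Mapping, Optional
from uuid import uuid4


def resolve_mlflow_settings(env: Optional[Mapping[str, str]] = None) -> dict:
    """Hypothetical helper: read the three variables, falling back to defaults."""
    env = os.environ if env is None else env
    return {
        # An empty URI leaves the choice of storage location to mlflow.
        "tracking_uri": env.get("MLFLOW_TRACKING_URI", ""),
        # Missing names are replaced by random UUID strings.
        "experiment_name": env.get("MLFLOW_EXPERIMENT", str(uuid4())),
        "run_name": env.get("MLFLOW_RUN_NAME", str(uuid4())),
    }


# Only the URI and the experiment name are set; the run name is generated.
settings = resolve_mlflow_settings(
    {"MLFLOW_TRACKING_URI": "sqlite:///mlruns.db", "MLFLOW_EXPERIMENT": "test_experiment"}
)
print(settings["tracking_uri"])
print(settings["experiment_name"])
print(settings["run_name"])
```

Because the generated names differ on every call, setting `MLFLOW_EXPERIMENT` and `MLFLOW_RUN_NAME` explicitly makes runs easier to find again in the UI.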

### Setup
Let's do the necessary imports and declare a `DataLoader`. We can already define some hyperparameters here, including the seed for random number generators. mlflow can log hyperparameters with arbitrary types, for example the observable that we want to monitor (`Z` in this case, which has a `qadence.Operation` type).

```python
import random
from itertools import count

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from torch.nn import Module
from torch.utils.data import DataLoader

from qadence import hea, QuantumCircuit, Z
from qadence.constructors import feature_map, hamiltonian_factory
from qadence.ml_tools import train_with_grad, TrainConfig
from qadence.ml_tools.data import to_dataloader
from qadence.ml_tools.utils import rand_featureparameters
from qadence.models import QNN, QuantumModel
from qadence.types import ExperimentTrackingTool

hyperparams = {
    "seed": 42,
    "batch_size": 10,
    "n_qubits": 2,
    "ansatz_depth": 1,
    "observable": Z,
}

np.random.seed(hyperparams["seed"])
torch.manual_seed(hyperparams["seed"])
random.seed(hyperparams["seed"])


def dataloader(batch_size: int = 25) -> DataLoader:
    x = torch.linspace(0, 1, batch_size).reshape(-1, 1)
    y = torch.cos(x)
    return to_dataloader(x, y, batch_size=batch_size, infinite=True)
```
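
Since `Z` is not a plain scalar, it helps to know how such values are treated by each tracking tool. The `TrainConfig` code added in this commit leaves the hyperparameters untouched for mlflow, but drops those that tensorboard cannot log. The sketch below illustrates that filtering with the standard library only. The helper name `filter_loggable`, the stand-in class `Z` and the reduced type tuple (the real one also contains `torch.Tensor`) are assumptions made for illustration.

```python
import logging

logger = logging.getLogger(__name__)

# Assumption: plain Python scalars only; the code in this commit also allows torch.Tensor.
ALLOWED_TYPES = (int, float, str, bool)


def filter_loggable(hyperparams: dict) -> dict:
    """Hypothetical helper: return a copy without values of unsupported types."""
    dropped = [key for key, value in hyperparams.items() if not isinstance(value, ALLOWED_TYPES)]
    if dropped:
        logger.warning("Cannot log the following hyperparameters: %s.", dropped)
    return {key: value for key, value in hyperparams.items() if key not in dropped}


class Z:
    """Stand-in for the qadence operation, used only for illustration."""


filtered = filter_loggable({"seed": 42, "batch_size": 10, "observable": Z})
print(filtered)
```

In this sketch the `observable` entry is removed and a warning is emitted, while the integer entries pass through unchanged.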

We continue with the regular QNN definition, together with the loss function and optimizer.

```python
obs = hamiltonian_factory(register=hyperparams["n_qubits"], detuning=hyperparams["observable"])

data = dataloader(hyperparams["batch_size"])
fm = feature_map(hyperparams["n_qubits"], param="x")

model = QNN(
    QuantumCircuit(
        hyperparams["n_qubits"], fm, hea(hyperparams["n_qubits"], hyperparams["ansatz_depth"])
    ),
    observable=obs,
    inputs=["x"],
)

cnt = count()
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)

inputs = rand_featureparameters(model, 1)

def loss_fn(model: QuantumModel, data: torch.Tensor) -> tuple[torch.Tensor, dict]:
    next(cnt)
    out = model.expectation(inputs)
    loss = criterion(out, torch.rand(1))
    return loss, {}
```

### `TrainConfig` specifications
Qadence offers different tracking options via `TrainConfig`. Here we use the `ExperimentTrackingTool` type to specify that we want to track the experiment with mlflow. Tracking with tensorboard is also possible. We can then indicate *what* and *how often* we want to track or log. `write_every` controls the number of epochs after which the loss values are logged. Thanks to the `plotting_functions` and `plot_every` arguments, we can also plot model-related quantities throughout training. Note that arbitrary plotting functions can be passed, as long as their signature matches that of `plot_fn` below. Finally, the trained model can be logged by setting `log_model=True`. Here is an example of a plotting function and a training configuration:

```python
def plot_fn(model: Module, iteration: int) -> tuple[str, Figure]:
    descr = f"ufa_prediction_epoch_{iteration}.png"
    fig, ax = plt.subplots()
    x = torch.linspace(0, 1, 100).reshape(-1, 1)
    out = model.expectation(x)
    ax.plot(x.detach().numpy(), out.detach().numpy())
    return descr, fig


config = TrainConfig(
    folder="mlflow_demonstration",
    max_iter=10,
    checkpoint_every=1,
    plot_every=2,
    write_every=1,
    log_model=True,
    tracking_tool=ExperimentTrackingTool.MLFLOW,
    hyperparams=hyperparams,
    plotting_functions=(plot_fn,),
)
```
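
To get a feel for how `write_every`, `plot_every` and `checkpoint_every` interact, the sketch below lists which actions would fire at each iteration for the configuration above. It assumes that iterations are counted from 1 and that an action fires whenever the iteration number is divisible by the corresponding setting. The actual training loop may differ in these details, and `logging_schedule` is a hypothetical helper, not part of Qadence.

```python
def logging_schedule(
    max_iter: int, write_every: int, plot_every: int, checkpoint_every: int
) -> dict:
    """Hypothetical helper: map each iteration to the actions assumed to fire."""
    schedule = {}
    for iteration in range(1, max_iter + 1):
        actions = []
        if iteration % write_every == 0:
            actions.append("write")
        if iteration % plot_every == 0:
            actions.append("plot")
        if iteration % checkpoint_every == 0:
            actions.append("checkpoint")
        if actions:
            schedule[iteration] = actions
    return schedule


# Same values as in the TrainConfig above.
schedule = logging_schedule(max_iter=10, write_every=1, plot_every=2, checkpoint_every=1)
for iteration, actions in schedule.items():
    print(iteration, actions)
```

Under these assumptions, the loss is written and a checkpoint saved at every iteration, while a figure is produced at every second one.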

### Training and inspecting
Model training happens as usual:
```python
train_with_grad(model, data, optimizer, config, loss_fn=loss_fn)
```

After training, we can inspect our experiment via the mlflow UI:
```bash
mlflow ui --port 8080 --backend-store-uri sqlite:///mlruns.db
```
In this case, since we're running on a local server, we can access the mlflow UI by navigating to http://localhost:8080/.
11 changes: 6 additions & 5 deletions pyproject.toml
Expand Up @@ -21,10 +21,11 @@ authors = [
{ name = "Smit Chaudhary", email = "[email protected]" },
{ name = "Ignacio Fernández Graña", email = "[email protected]" },
{ name = "Charles Moussa", email = "[email protected]" },
{ name = "Giorgio Tosti Balducci", email = "[email protected]" },
]
requires-python = ">=3.9"
license = { text = "Apache 2.0" }
version = "1.7.2"
version = "1.7.3"
classifiers = [
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python",
Expand Down Expand Up @@ -83,8 +84,8 @@ horqrux = [
protocols = ["qadence-protocols"]
libs = ["qadence-libs"]
dlprof = ["nvidia-pyindex", "nvidia-dlprof[pytorch]"]
all = ["pulser", "braket", "visualization", "protocols", "libs"]

mlflow = ["mlflow"]
all = ["pulser", "braket", "visualization", "protocols", "libs", "mlflow"]

[tool.hatch.envs.default]
dependencies = [
Expand All @@ -102,7 +103,7 @@ dependencies = [
"ruff",
"pydocstringformatter",
]
features = ["pulser", "braket", "visualization", "horqrux"]
features = ["pulser", "braket", "visualization", "horqrux", "mlflow"]

[tool.hatch.envs.default.scripts]
test = "pytest -n auto --cov-report lcov --cov-config=pyproject.toml --cov=qadence --cov=tests --ignore=./tests/test_examples.py {args}"
Expand Down Expand Up @@ -139,7 +140,7 @@ dependencies = [
"markdown-exec",
"mike",
]
features = ["pulser", "braket", "horqrux", "visualization"]
features = ["pulser", "braket", "horqrux", "visualization", "mlflow"]

[tool.hatch.envs.docs.scripts]
build = "mkdocs build --clean --strict"
108 changes: 102 additions & 6 deletions qadence/ml_tools/config.py
Expand Up @@ -5,15 +5,25 @@
from dataclasses import dataclass, field, fields
from logging import getLogger
from pathlib import Path
from typing import Callable, Optional, Type
from typing import Callable, Type
from uuid import uuid4

from sympy import Basic
from torch import Tensor

from qadence.blocks.analog import AnalogBlock
from qadence.blocks.primitive import ParametricBlock
from qadence.operations import RX, AnalogRX
from qadence.parameters import Parameter
from qadence.types import AnsatzType, BasisSet, MultivariateStrategy, ReuploadScaling, Strategy
from qadence.types import (
    AnsatzType,
    BasisSet,
    ExperimentTrackingTool,
    LoggablePlotFunction,
    MultivariateStrategy,
    ReuploadScaling,
    Strategy,
)

logger = getLogger(__file__)

Expand All @@ -37,10 +47,14 @@ class TrainConfig:
    print_every: int = 1000
    """Print loss/metrics."""
    write_every: int = 50
    """Write tensorboard logs."""
    """Write loss and metrics with the tracking tool."""
    checkpoint_every: int = 5000
    """Write model/optimizer checkpoint."""
    folder: Optional[Path] = None
    plot_every: int = 5000
    """Write figures."""
    log_model: bool = False
    """Logs a serialised version of the model."""
    folder: Path | None = None
    """Checkpoint/tensorboard logs folder."""
    create_subfolder_per_run: bool = False
    """Checkpoint/tensorboard logs stored in subfolder with name `<timestamp>_<PID>`.
Expand All @@ -59,14 +73,38 @@ class TrainConfig:
    validation loss across previous iterations.
    """
    validation_criterion: Optional[Callable] = None
    validation_criterion: Callable | None = None
    """A boolean function which evaluates a given validation metric is satisfied."""
    trainstop_criterion: Optional[Callable] = None
    trainstop_criterion: Callable | None = None
    """A boolean function which evaluates a given training stopping metric is satisfied."""
    batch_size: int = 1
    """The batch_size to use when passing a list/tuple of torch.Tensors."""
    verbose: bool = True
    """Whether or not to print out metrics values during training."""
    tracking_tool: ExperimentTrackingTool = ExperimentTrackingTool.TENSORBOARD
    """The tracking tool of choice."""
    hyperparams: dict = field(default_factory=dict)
    """Hyperparameters to track."""
    plotting_functions: tuple[LoggablePlotFunction, ...] = field(default_factory=tuple)  # type: ignore
    """Functions for in-train plotting."""

    # tensorboard only allows for certain types as hyperparameters
    _tb_allowed_hyperparams_types: tuple = field(
        default=(int, float, str, bool, Tensor), init=False, repr=False
    )

    def _filter_tb_hyperparams(self) -> None:
        keys_to_remove = [
            key
            for key, value in self.hyperparams.items()
            if not isinstance(value, TrainConfig._tb_allowed_hyperparams_types)
        ]
        if keys_to_remove:
            logger.warning(
                f"Tensorboard cannot log the following hyperparameters: {keys_to_remove}."
            )
            for key in keys_to_remove:
                self.hyperparams.pop(key)

    def __post_init__(self) -> None:
        if self.folder:
Expand All @@ -81,6 +119,64 @@ def __post_init__(self) -> None:
            self.trainstop_criterion = lambda x: x <= self.max_iter
        if self.validation_criterion is None:
            self.validation_criterion = lambda *x: False
        if self.hyperparams and self.tracking_tool == ExperimentTrackingTool.TENSORBOARD:
            self._filter_tb_hyperparams()
        if self.tracking_tool == ExperimentTrackingTool.MLFLOW:
            self._mlflow_config = MLFlowConfig()
        if self.plotting_functions and self.tracking_tool != ExperimentTrackingTool.MLFLOW:
            logger.warning("In-training plots are only available with mlflow tracking.")
        if not self.plotting_functions and self.tracking_tool == ExperimentTrackingTool.MLFLOW:
            logger.warning("Tracking with mlflow, but no plotting functions provided.")

    @property
    def mlflow_config(self) -> MLFlowConfig:
        if self.tracking_tool == ExperimentTrackingTool.MLFLOW:
            return self._mlflow_config
        else:
            raise AttributeError(
                "mlflow_config is available only for with the mlflow tracking tool."
            )


class MLFlowConfig:
    """
    Configuration for mlflow tracking.
    Example:
        export MLFLOW_TRACKING_URI=tracking_uri
        export MLFLOW_EXPERIMENT=experiment_name
        export MLFLOW_RUN_NAME=run_name
    """

    def __init__(self) -> None:
        import mlflow

        self.tracking_uri: str = os.getenv("MLFLOW_TRACKING_URI", "")
        """The URI of the mlflow tracking server.
        An empty string, or a local file path, prefixed with file:/.
        Data is stored locally at the provided file (or ./mlruns if empty).
        """

        self.experiment_name: str = os.getenv("MLFLOW_EXPERIMENT", str(uuid4()))
        """The name of the experiment.
        If None or empty, a new experiment is created with a random UUID.
        """

        self.run_name: str = os.getenv("MLFLOW_RUN_NAME", str(uuid4()))
        """The name of the run."""

        mlflow.set_tracking_uri(self.tracking_uri)

        # activate existing or create experiment
        exp_filter_string = f"name = '{self.experiment_name}'"
        if not mlflow.search_experiments(filter_string=exp_filter_string):
            mlflow.create_experiment(name=self.experiment_name)

        self.experiment = mlflow.set_experiment(self.experiment_name)
        self.run = mlflow.start_run(run_name=self.run_name, nested=False)


@dataclass