Skip to content

Commit

Permalink
[bug] Resume dictdataloader support for Trainer (#627)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlahariya authored Dec 5, 2024
1 parent cf4d6ac commit ad4051b
Show file tree
Hide file tree
Showing 4 changed files with 185 additions and 87 deletions.
61 changes: 36 additions & 25 deletions qadence/ml_tools/callbacks/writer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,14 @@
from typing import Any, Callable, Union
from uuid import uuid4

import mlflow
from matplotlib.figure import Figure
from mlflow.entities import Run
from mlflow.models import infer_signature
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.data import OptimizeResult
from qadence.ml_tools.data import DictDataLoader, OptimizeResult
from qadence.types import ExperimentTrackingTool

logger = getLogger("ml_tools")
Expand All @@ -43,7 +40,7 @@ class BaseWriter(ABC):
log_model(model, dataloader): Logs the model and any relevant information.
"""

run: Run # [attr-defined]
run: Any # [attr-defined]

@abstractmethod
def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
Expand Down Expand Up @@ -104,18 +101,18 @@ def plot(
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model and associated data.
Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
raise NotImplementedError("Writers must implement a log_model method.")

Expand Down Expand Up @@ -231,9 +228,9 @@ def plot(
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model.
Expand All @@ -242,9 +239,9 @@ def log_model(
Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
logger.warning("Model logging is not supported by tensorboard. No model will be logged.")

Expand All @@ -259,6 +256,14 @@ class MLFlowWriter(BaseWriter):
"""

def __init__(self) -> None:
try:
from mlflow.entities import Run
except ImportError:
raise ImportError(
"mlflow is not installed. Please install qadence with the mlflow feature: "
"`pip install qadence[mlflow]`."
)

self.run: Run
self.mlflow: ModuleType

Expand All @@ -274,6 +279,8 @@ def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType
Returns:
mlflow: The MLflow module instance.
"""
import mlflow

self.mlflow = mlflow
tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4()))
Expand Down Expand Up @@ -356,17 +363,21 @@ def plot(
"Please call the 'writer.open()' method before writing"
)

def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader | None) -> Any:
def get_signature_from_dataloader(
self, model: Module, dataloader: DataLoader | DictDataLoader | None
) -> Any:
"""
Infers the signature of the model based on the input data from the dataloader.
Args:
model (Module): The model to use for inference.
dataloader (DataLoader | None): DataLoader for model inputs.
dataloader (DataLoader | DictDataLoader | None): DataLoader for model inputs.
Returns:
Optional[Any]: The inferred signature, if available.
"""
from mlflow.models import infer_signature

if dataloader is None:
return None

Expand All @@ -384,18 +395,18 @@ def get_signature_from_dataloader(self, model: Module, dataloader: DataLoader |
def log_model(
self,
model: Module,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
) -> None:
"""
Logs the model and its signature to MLflow using the provided data loaders.
Args:
model (Module): The model to log.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
"""
if not self.mlflow:
raise RuntimeError(
Expand Down
59 changes: 33 additions & 26 deletions qadence/ml_tools/train_utils/base_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import torch
from nevergrad.optimization.base import Optimizer as NGOptimizer
from torch import nn, optim
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, TensorDataset

from qadence.ml_tools.callbacks import CallbacksManager
from qadence.ml_tools.config import TrainConfig
from qadence.ml_tools.data import InfiniteTensorDataset
from qadence.ml_tools.data import DictDataLoader
from qadence.ml_tools.loss import get_loss_fn
from qadence.ml_tools.optimize_step import optimize_step
from qadence.ml_tools.parameters import get_parameters
Expand Down Expand Up @@ -42,9 +42,9 @@ class BaseTrainer:
model (nn.Module): The neural network model.
optimizer (optim.Optimizer | NGOptimizer | None): The optimizer for training.
config (TrainConfig): The configuration settings for training.
train_dataloader (DataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
optimize_step (Callable): Function for performing an optimization step.
loss_fn (Callable | str): loss function to use. Default loss function
Expand All @@ -69,9 +69,9 @@ def __init__(
config: TrainConfig,
loss_fn: str | Callable = "mse",
optimize_step: Callable = optimize_step,
train_dataloader: DataLoader | None = None,
val_dataloader: DataLoader | None = None,
test_dataloader: DataLoader | None = None,
train_dataloader: DataLoader | DictDataLoader | None = None,
val_dataloader: DataLoader | DictDataLoader | None = None,
test_dataloader: DataLoader | DictDataLoader | None = None,
max_batches: int | None = None,
):
"""
Expand All @@ -86,11 +86,11 @@ def __init__(
str input to be specified to use a default loss function.
currently supported loss functions: 'mse', 'cross_entropy'.
If not specified, default mse loss will be used.
train_dataloader (DataLoader | None): DataLoader for training data.
train_dataloader (DataLoader | DictDataLoader | None): DataLoader for training data.
If the model does not need data to evaluate loss, no dataset
should be provided.
val_dataloader (DataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | None): DataLoader for testing data.
val_dataloader (DataLoader | DictDataLoader | None): DataLoader for validation data.
test_dataloader (DataLoader | DictDataLoader | None): DataLoader for testing data.
max_batches (int | None): Maximum number of batches to process per epoch.
This is only valid in case of finite TensorDataset dataloaders.
if max_batches is not None, the maximum number of batches used will
Expand All @@ -100,9 +100,9 @@ def __init__(
self._model: nn.Module
self._optimizer: optim.Optimizer | NGOptimizer | None
self._config: TrainConfig
self._train_dataloader: DataLoader | None = None
self._val_dataloader: DataLoader | None = None
self._test_dataloader: DataLoader | None = None
self._train_dataloader: DataLoader | DictDataLoader | None = None
self._val_dataloader: DataLoader | DictDataLoader | None = None
self._test_dataloader: DataLoader | DictDataLoader | None = None

self.config = config
self.model = model
Expand Down Expand Up @@ -311,7 +311,7 @@ def config(self, value: TrainConfig) -> None:
self.callback_manager = CallbacksManager(value)
self.config_manager = ConfigManager(value)

def _compute_num_batches(self, dataloader: DataLoader) -> int:
def _compute_num_batches(self, dataloader: DataLoader | DictDataLoader) -> int:
"""
Computes the number of batches for the given DataLoader.
Expand All @@ -321,34 +321,41 @@ def _compute_num_batches(self, dataloader: DataLoader) -> int:
"""
if dataloader is None:
return 1
dataset = dataloader.dataset
if isinstance(dataset, InfiniteTensorDataset):
return 1
if isinstance(dataloader, DictDataLoader):
dataloader_name, dataloader_value = list(dataloader.dataloaders.items())[0]
dataset = dataloader_value.dataset
batch_size = dataloader_value.batch_size
else:
n_batches = int(
(dataset.tensors[0].size(0) + dataloader.batch_size - 1) // dataloader.batch_size
)
dataset = dataloader.dataset
batch_size = dataloader.batch_size

if isinstance(dataset, TensorDataset):
n_batches = int((dataset.tensors[0].size(0) + batch_size - 1) // batch_size)
return min(self.max_batches, n_batches) if self.max_batches is not None else n_batches
else:
return 1

def _validate_dataloader(self, dataloader: DataLoader, dataloader_type: str) -> None:
def _validate_dataloader(
self, dataloader: DataLoader | DictDataLoader, dataloader_type: str
) -> None:
"""
Validates the type of the DataLoader and raises errors for unsupported types.
Args:
dataloader (DataLoader): The DataLoader to validate.
dataloader (DataLoader | DictDataLoader): The DataLoader to validate.
dataloader_type (str): The type of DataLoader ("train", "val", or "test").
"""
if dataloader is not None:
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, (DataLoader, DictDataLoader)):
raise NotImplementedError(
f"Unsupported dataloader type: {type(dataloader)}."
"The dataloader must be an instance of DataLoader."
)
if dataloader_type == "val" and self.config.val_every > 0:
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, (DataLoader, DictDataLoader)):
raise ValueError(
"If `config.val_every` is provided as an integer > 0, validation_dataloader"
"must be an instance of `DataLoader`."
"must be an instance of `DataLoader` or `DictDataLoader`."
)

@staticmethod
Expand Down
Loading

0 comments on commit ad4051b

Please sign in to comment.