diff --git a/qadence/ml_tools/callbacks/writer_registry.py b/qadence/ml_tools/callbacks/writer_registry.py index 8c02bb85..5bd07354 100644 --- a/qadence/ml_tools/callbacks/writer_registry.py +++ b/qadence/ml_tools/callbacks/writer_registry.py @@ -7,10 +7,7 @@ from typing import Any, Callable, Union from uuid import uuid4 -import mlflow from matplotlib.figure import Figure -from mlflow.entities import Run -from mlflow.models import infer_signature from torch import Tensor from torch.nn import Module from torch.utils.data import DataLoader @@ -43,7 +40,7 @@ class BaseWriter(ABC): log_model(model, dataloader): Logs the model and any relevant information. """ - run: Run # [attr-defined] + run: Any # [attr-defined] @abstractmethod def open(self, config: TrainConfig, iteration: int | None = None) -> Any: @@ -259,6 +256,8 @@ class MLFlowWriter(BaseWriter): """ def __init__(self) -> None: + from mlflow.entities import Run + self.run: Run self.mlflow: ModuleType @@ -274,6 +273,8 @@ def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType Returns: mlflow: The MLflow module instance. """ + import mlflow + self.mlflow = mlflow tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "") experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4())) @@ -369,6 +370,8 @@ def get_signature_from_dataloader( Returns: Optional[Any]: The inferred signature, if available. """ + from mlflow.models import infer_signature + if dataloader is None: return None