Skip to content

Commit

Permalink
Update mlflow to be optional dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
mlahariya committed Dec 5, 2024
1 parent 0411419 commit 986f269
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions qadence/ml_tools/callbacks/writer_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,7 @@
from typing import Any, Callable, Union
from uuid import uuid4

import mlflow
from matplotlib.figure import Figure
from mlflow.entities import Run
from mlflow.models import infer_signature
from torch import Tensor
from torch.nn import Module
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -43,7 +40,7 @@ class BaseWriter(ABC):
log_model(model, dataloader): Logs the model and any relevant information.
"""

run: Run # [attr-defined]
run: Any # [attr-defined]

@abstractmethod
def open(self, config: TrainConfig, iteration: int | None = None) -> Any:
Expand Down Expand Up @@ -259,6 +256,8 @@ class MLFlowWriter(BaseWriter):
"""

def __init__(self) -> None:
from mlflow.entities import Run

self.run: Run
self.mlflow: ModuleType

Expand All @@ -274,6 +273,8 @@ def open(self, config: TrainConfig, iteration: int | None = None) -> ModuleType
Returns:
mlflow: The MLflow module instance.
"""
import mlflow

self.mlflow = mlflow
tracking_uri = os.getenv("MLFLOW_TRACKING_URI", "")
experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", str(uuid4()))
Expand Down Expand Up @@ -369,6 +370,8 @@ def get_signature_from_dataloader(
Returns:
Optional[Any]: The inferred signature, if available.
"""
from mlflow.models import infer_signature

if dataloader is None:
return None

Expand Down

0 comments on commit 986f269

Please sign in to comment.