Skip to content

Commit

Permalink
[fix] Fix pytorch_lightning aliases issue (#2747)
Browse files Browse the repository at this point in the history
  • Loading branch information
popfido authored Nov 2, 2023
1 parent ec08b8d commit 1d0e827
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
14 changes: 12 additions & 2 deletions aim/sdk/adapters/pytorch_lightning.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
import os
import importlib.util
from typing import Any, Dict, Optional, Union
from argparse import Namespace

import packaging.version

try:
if importlib.util.find_spec("lightning"):
import lightning.pytorch as pl

from lightning.pytorch.loggers.logger import (
Logger, rank_zero_experiment
)

from lightning.pytorch.utilities import rank_zero_only
elif importlib.util.find_spec("pytorch_lightning"):
import pytorch_lightning as pl

if packaging.version.parse(pl.__version__) < packaging.version.parse("1.7"):
Expand All @@ -19,10 +28,11 @@
)

from pytorch_lightning.utilities import rank_zero_only
except ImportError:
else:
raise RuntimeError(
'This contrib module requires PyTorch Lightning to be installed. '
'Please install it with command: \n pip install pytorch-lightning'
'or \n pip install lightning'
)

from aim.sdk.run import Run
Expand Down
12 changes: 11 additions & 1 deletion examples/pytorch_lightning_track.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
import importlib.util
from aim.pytorch_lightning import AimLogger

from argparse import ArgumentParser

import torch
import pytorch_lightning as pl
if importlib.util.find_spec("lightning"):
import lightning.pytorch as pl
elif importlib.util.find_spec("pytorch_lightning"): # noqa F401
import pytorch_lightning as pl
else:
raise RuntimeError(
'This contrib module requires PyTorch Lightning to be installed. '
'Please install it with command: \n pip install pytorch-lightning \n'
'or \n pip install lightning'
)
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split

Expand Down

0 comments on commit 1d0e827

Please sign in to comment.