Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleaner datadir management for some tests #15791

Merged
merged 3 commits into from
Nov 25, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions src/pytorch_lightning/demos/boring_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,11 +156,8 @@ def predict_dataloader(self) -> DataLoader:


class BoringDataModule(LightningDataModule):
def __init__(self, data_dir: str = "./"):
def __init__(self) -> None:
super().__init__()
self.data_dir = data_dir
self.non_picklable = None
self.checkpoint_state: Optional[str] = None
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self.random_full = RandomDataset(32, 64 * 4)

def setup(self, stage: str) -> None:
Expand Down
5 changes: 3 additions & 2 deletions src/pytorch_lightning/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

from torch import Tensor

from lightning_lite.utilities.types import _PATH
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params
Expand Down Expand Up @@ -125,14 +126,14 @@ class CSVLogger(Logger):

def __init__(
self,
save_dir: str,
save_dir: _PATH,
name: str = "lightning_logs",
version: Optional[Union[int, str]] = None,
prefix: str = "",
flush_logs_every_n_steps: int = 100,
):
super().__init__()
self._save_dir = save_dir
self._save_dir = os.fspath(save_dir)
carmocca marked this conversation as resolved.
Show resolved Hide resolved
self._name = name or ""
self._version = version
self._prefix = prefix
Expand Down
12 changes: 9 additions & 3 deletions tests/tests_pytorch/core/test_datamodules.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,18 +126,24 @@ def test_helper_boringdatamodule_with_verbose_setup():
dm.setup("test")


class DataDirDataModule(BoringDataModule):
def __init__(self, data_dir: str):
super().__init__()
self.data_dir = data_dir


def test_dm_add_argparse_args(tmpdir):
parser = ArgumentParser()
parser = BoringDataModule.add_argparse_args(parser)
parser = DataDirDataModule.add_argparse_args(parser)
args = parser.parse_args(["--data_dir", str(tmpdir)])
assert args.data_dir == str(tmpdir)


def test_dm_init_from_argparse_args(tmpdir):
parser = ArgumentParser()
parser = BoringDataModule.add_argparse_args(parser)
parser = DataDirDataModule.add_argparse_args(parser)
args = parser.parse_args(["--data_dir", str(tmpdir)])
dm = BoringDataModule.from_argparse_args(args)
dm = DataDirDataModule.from_argparse_args(args)
dm.prepare_data()
dm.setup("fit")
assert dm.data_dir == args.data_dir == str(tmpdir)
Expand Down
Loading