diff --git a/src/pytorch_lightning/demos/boring_classes.py b/src/pytorch_lightning/demos/boring_classes.py
index 31483db9bd53b..9967cdddf7029 100644
--- a/src/pytorch_lightning/demos/boring_classes.py
+++ b/src/pytorch_lightning/demos/boring_classes.py
@@ -156,11 +156,8 @@ def predict_dataloader(self) -> DataLoader:
 
 
 class BoringDataModule(LightningDataModule):
-    def __init__(self, data_dir: str = "./"):
+    def __init__(self) -> None:
         super().__init__()
-        self.data_dir = data_dir
-        self.non_picklable = None
-        self.checkpoint_state: Optional[str] = None
         self.random_full = RandomDataset(32, 64 * 4)
 
     def setup(self, stage: str) -> None:
diff --git a/src/pytorch_lightning/loggers/csv_logs.py b/src/pytorch_lightning/loggers/csv_logs.py
index 5b2d961bae11f..202f35676a363 100644
--- a/src/pytorch_lightning/loggers/csv_logs.py
+++ b/src/pytorch_lightning/loggers/csv_logs.py
@@ -26,6 +26,7 @@
 
 from torch import Tensor
 
+from lightning_lite.utilities.types import _PATH
 from pytorch_lightning.core.saving import save_hparams_to_yaml
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params
@@ -125,14 +126,14 @@ class CSVLogger(Logger):
 
     def __init__(
         self,
-        save_dir: str,
+        save_dir: _PATH,
         name: str = "lightning_logs",
         version: Optional[Union[int, str]] = None,
         prefix: str = "",
         flush_logs_every_n_steps: int = 100,
     ):
         super().__init__()
-        self._save_dir = save_dir
+        self._save_dir = os.fspath(save_dir)
         self._name = name or ""
         self._version = version
         self._prefix = prefix
diff --git a/tests/tests_pytorch/core/test_datamodules.py b/tests/tests_pytorch/core/test_datamodules.py
index d951f2ad55d9e..53493ed8cc103 100644
--- a/tests/tests_pytorch/core/test_datamodules.py
+++ b/tests/tests_pytorch/core/test_datamodules.py
@@ -126,18 +126,24 @@ def test_helper_boringdatamodule_with_verbose_setup():
     dm.setup("test")
 
 
+class DataDirDataModule(BoringDataModule):
+    def __init__(self, data_dir: str):
+        super().__init__()
+        self.data_dir = data_dir
+
+
 def test_dm_add_argparse_args(tmpdir):
     parser = ArgumentParser()
-    parser = BoringDataModule.add_argparse_args(parser)
+    parser = DataDirDataModule.add_argparse_args(parser)
     args = parser.parse_args(["--data_dir", str(tmpdir)])
     assert args.data_dir == str(tmpdir)
 
 
 def test_dm_init_from_argparse_args(tmpdir):
     parser = ArgumentParser()
-    parser = BoringDataModule.add_argparse_args(parser)
+    parser = DataDirDataModule.add_argparse_args(parser)
     args = parser.parse_args(["--data_dir", str(tmpdir)])
-    dm = BoringDataModule.from_argparse_args(args)
+    dm = DataDirDataModule.from_argparse_args(args)
     dm.prepare_data()
     dm.setup("fit")
     assert dm.data_dir == args.data_dir == str(tmpdir)
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index 79562e52e3fea..dd4b64767e4a9 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -11,11 +11,13 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import glob
 import inspect
 import json
 import os
 from contextlib import contextmanager, ExitStack, redirect_stdout
 from io import StringIO
+from pathlib import Path
 from typing import Callable, List, Optional, Union
 from unittest import mock
 from unittest.mock import ANY
@@ -39,6 +41,7 @@
 )
 from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
 from pytorch_lightning.loggers import _COMET_AVAILABLE, TensorBoardLogger
+from pytorch_lightning.loggers.csv_logs import CSVLogger
 from pytorch_lightning.loggers.neptune import _NEPTUNE_AVAILABLE
 from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE
 from pytorch_lightning.strategies import DDPStrategy
@@ -65,6 +68,19 @@ def mock_subclasses(baseclass, *subclasses):
         yield None
 
 
+@pytest.fixture
+def cleandir(tmp_path, monkeypatch):
+    monkeypatch.chdir(tmp_path)
+    yield
+
+
+@pytest.fixture(autouse=True)
+def ensure_cleandir():
+    yield
+    # make sure tests don't leave configuration files
+    assert not glob.glob("*.yaml")
+
+
 @pytest.mark.parametrize("cli_args", [["--callbacks=1", "--logger"], ["--foo", "--bar=1"]])
 def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
     """Asserts error raised in case of passing not default cli arguments."""
@@ -132,7 +148,7 @@ def on_train_start(callback, trainer, _):
     assert hasattr(cli.trainer, "ran_asserts") and cli.trainer.ran_asserts
 
 
-def test_lightning_cli_args_callbacks(tmpdir):
+def test_lightning_cli_args_callbacks(cleandir):
 
     callbacks = [
         dict(
@@ -154,7 +170,7 @@ def on_fit_start(self):
             self.trainer.ran_asserts = True
 
     with mock.patch("sys.argv", ["any.py", "fit", f"--trainer.callbacks={json.dumps(callbacks)}"]):
-        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))
+        cli = LightningCLI(TestModel, trainer_defaults=dict(fast_dev_run=True, logger=CSVLogger(".")))
 
     assert cli.trainer.ran_asserts
 
@@ -168,7 +184,7 @@ def test_lightning_cli_single_arg_callback():
 
 
 @pytest.mark.parametrize("run", (False, True))
-def test_lightning_cli_configurable_callbacks(tmpdir, run):
+def test_lightning_cli_configurable_callbacks(cleandir, run):
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_lightning_class_args(LearningRateMonitor, "learning_rate_monitor")
@@ -177,7 +193,7 @@ def fit(self, **_):
             pass
 
     cli_args = ["fit"] if run else []
-    cli_args += [f"--trainer.default_root_dir={tmpdir}", "--learning_rate_monitor.logging_interval=epoch"]
+    cli_args += ["--learning_rate_monitor.logging_interval=epoch"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = MyLightningCLI(BoringModel, run=run)
@@ -187,7 +203,7 @@ def fit(self, **_):
     assert callback[0].logging_interval == "epoch"
 
 
-def test_lightning_cli_args_cluster_environments(tmpdir):
+def test_lightning_cli_args_cluster_environments(cleandir):
     plugins = [dict(class_path="lightning_lite.plugins.environments.SLURMEnvironment")]
 
     class TestModel(BoringModel):
@@ -197,26 +213,32 @@ def on_fit_start(self):
             self.trainer.ran_asserts = True
 
     with mock.patch("sys.argv", ["any.py", "fit", f"--trainer.plugins={json.dumps(plugins)}"]):
-        cli = LightningCLI(TestModel, trainer_defaults=dict(default_root_dir=str(tmpdir), fast_dev_run=True))
+        cli = LightningCLI(TestModel, trainer_defaults=dict(fast_dev_run=True))
 
     assert cli.trainer.ran_asserts
 
 
-def test_lightning_cli_args(tmpdir):
+class DataDirDataModule(BoringDataModule):
+    def __init__(self, data_dir):
+        super().__init__()
+
+
+def test_lightning_cli_args(cleandir):
     cli_args = [
         "fit",
-        f"--data.data_dir={tmpdir}",
-        f"--trainer.default_root_dir={tmpdir}",
+        "--data.data_dir=.",
         "--trainer.max_epochs=1",
+        "--trainer.limit_train_batches=1",
+        "--trainer.limit_val_batches=0",
         "--trainer.enable_model_summary=False",
+        "--trainer.logger=False",
         "--seed_everything=1234",
     ]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
-        cli = LightningCLI(BoringModel, BoringDataModule, trainer_defaults={"callbacks": [LearningRateMonitor()]})
+        cli = LightningCLI(BoringModel, DataDirDataModule)
 
-    config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
+    config_path = "config.yaml"
     assert os.path.isfile(config_path)
     with open(config_path) as f:
         loaded_config = yaml.safe_load(f.read())
@@ -228,10 +250,10 @@
     assert loaded_config["trainer"] == cli_config["trainer"]
 
 
-def test_lightning_env_parse(tmpdir):
+def test_lightning_env_parse(cleandir):
     out = StringIO()
     with mock.patch("sys.argv", ["", "fit", "--help"]), redirect_stdout(out), pytest.raises(SystemExit):
-        LightningCLI(BoringModel, BoringDataModule, env_parse=True)
+        LightningCLI(BoringModel, DataDirDataModule, env_parse=True)
     out = out.getvalue()
     assert "PL_FIT__CONFIG" in out
     assert "PL_FIT__SEED_EVERYTHING" in out
@@ -240,23 +262,23 @@
     assert "PL_FIT__CKPT_PATH" in out
 
     env_vars = {
-        "PL_FIT__DATA__DATA_DIR": str(tmpdir),
-        "PL_FIT__TRAINER__DEFAULT_ROOT_DIR": str(tmpdir),
+        "PL_FIT__DATA__DATA_DIR": ".",
+        "PL_FIT__TRAINER__DEFAULT_ROOT_DIR": ".",
         "PL_FIT__TRAINER__MAX_EPOCHS": "1",
-        "PL_FIT__TRAINER__LOGGER": "false",
+        "PL_FIT__TRAINER__LOGGER": "False",
     }
     with mock.patch.dict(os.environ, env_vars), mock.patch("sys.argv", ["", "fit"]):
-        cli = LightningCLI(BoringModel, BoringDataModule, env_parse=True)
-    assert cli.config.fit.data.data_dir == str(tmpdir)
-    assert cli.config.fit.trainer.default_root_dir == str(tmpdir)
+        cli = LightningCLI(BoringModel, DataDirDataModule, env_parse=True)
+    assert cli.config.fit.data.data_dir == "."
+    assert cli.config.fit.trainer.default_root_dir == "."
     assert cli.config.fit.trainer.max_epochs == 1
     assert cli.config.fit.trainer.logger is False
 
 
-def test_lightning_cli_save_config_cases(tmpdir):
+def test_lightning_cli_save_config_cases(cleandir):
 
-    config_path = tmpdir / "config.yaml"
-    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.logger=False", "--trainer.fast_dev_run=1"]
+    config_path = "config.yaml"
+    cli_args = ["fit", "--trainer.logger=false", "--trainer.fast_dev_run=1"]
 
     # With fast_dev_run!=False config should not be saved
     with mock.patch("sys.argv", ["any.py"] + cli_args):
@@ -274,9 +296,9 @@
         LightningCLI(BoringModel)
 
 
-def test_lightning_cli_save_config_only_once(tmpdir):
-    config_path = tmpdir / "config.yaml"
-    cli_args = [f"--trainer.default_root_dir={tmpdir}", "--trainer.logger=False", "--trainer.max_epochs=1"]
+def test_lightning_cli_save_config_only_once(cleandir):
+    config_path = "config.yaml"
+    cli_args = ["--trainer.logger=false", "--trainer.max_epochs=1"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = LightningCLI(BoringModel, run=False)
@@ -290,31 +312,33 @@
     cli.trainer.test(cli.model)  # Should not fail because config already saved
 
 
-def test_lightning_cli_config_and_subclass_mode(tmpdir):
+def test_lightning_cli_config_and_subclass_mode(cleandir):
     input_config = {
         "fit": {
             "model": {"class_path": "pytorch_lightning.demos.boring_classes.BoringModel"},
             "data": {
-                "class_path": "pytorch_lightning.demos.boring_classes.BoringDataModule",
-                "init_args": {"data_dir": str(tmpdir)},
+                "class_path": "DataDirDataModule",
+                "init_args": {"data_dir": "."},
             },
-            "trainer": {"default_root_dir": str(tmpdir), "max_epochs": 1, "enable_model_summary": False},
+            "trainer": {"max_epochs": 1, "enable_model_summary": False, "logger": False},
         }
     }
-    config_path = tmpdir / "config.yaml"
+    config_path = "config.yaml"
     with open(config_path, "w") as f:
         f.write(yaml.dump(input_config))
 
-    with mock.patch("sys.argv", ["any.py", "--config", str(config_path)]):
+    with mock.patch("sys.argv", ["any.py", "--config", config_path]), mock_subclasses(
+        LightningDataModule, DataDirDataModule
+    ):
         cli = LightningCLI(
             BoringModel,
             BoringDataModule,
             subclass_mode_model=True,
             subclass_mode_data=True,
-            trainer_defaults={"callbacks": LearningRateMonitor()},
+            save_config_kwargs={"overwrite": True},
        )
 
-    config_path = tmpdir / "lightning_logs" / "version_0" / "config.yaml"
+    config_path = "config.yaml"
     assert os.path.isfile(config_path)
     with open(config_path) as f:
         loaded_config = yaml.safe_load(f.read())
@@ -330,7 +354,6 @@ def any_model_any_data_cli():
 
 
 def test_lightning_cli_help():
-
     cli_args = ["any.py", "fit", "--help"]
     out = StringIO()
     with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
@@ -348,9 +371,11 @@
         if param not in skip_params:
             assert f"--trainer.{param}" in out
 
-    cli_args = ["any.py", "fit", "--data.help=pytorch_lightning.demos.boring_classes.BoringDataModule"]
+    cli_args = ["any.py", "fit", "--data.help=DataDirDataModule"]
     out = StringIO()
-    with mock.patch("sys.argv", cli_args), redirect_stdout(out), pytest.raises(SystemExit):
+    with mock.patch("sys.argv", cli_args), redirect_stdout(out), mock_subclasses(
+        LightningDataModule, DataDirDataModule
+    ), pytest.raises(SystemExit):
         any_model_any_data_cli()
 
     assert "--data.init_args.data_dir" in out.getvalue()
@@ -380,7 +405,7 @@
     assert outval["ckpt_path"] is None
 
 
-def test_lightning_cli_submodules(tmpdir):
+def test_lightning_cli_submodules(cleandir):
     class MainModule(BoringModel):
         def __init__(self, submodule1: LightningModule, submodule2: LightningModule, main_param: int = 1):
             super().__init__()
@@ -394,12 +419,10 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
         submodule2:
           class_path: pytorch_lightning.demos.boring_classes.BoringModel
         """
-    config_path = tmpdir / "config.yaml"
-    with open(config_path, "w") as f:
-        f.write(config)
-
-    cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"]
+    config_path = Path("config.yaml")
+    config_path.write_text(config)
+    cli_args = [f"--config={config_path}"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = LightningCLI(MainModule, run=False)
@@ -409,7 +432,7 @@ def __init__(self, submodule1: LightningModule, submodule2: LightningModule, mai
 
 
 @pytest.mark.skipif(not _TORCHVISION_AVAILABLE, reason=str(_TORCHVISION_AVAILABLE))
-def test_lightning_cli_torch_modules(tmpdir):
+def test_lightning_cli_torch_modules(cleandir):
     class TestModule(BoringModel):
         def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[torch.nn.Module]] = None):
             super().__init__()
@@ -429,12 +452,10 @@ def __init__(self, activation: torch.nn.Module = None, transform: Optional[List[
             init_args:
               size: 64
         """
-    config_path = tmpdir / "config.yaml"
-    with open(config_path, "w") as f:
-        f.write(config)
-
-    cli_args = [f"--trainer.default_root_dir={tmpdir}", f"--config={str(config_path)}"]
+    config_path = Path("config.yaml")
+    config_path.write_text(config)
+    cli_args = [f"--config={config_path}"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = LightningCLI(TestModule, run=False)
@@ -458,13 +479,13 @@ def __init__(self, batch_size: int = 8):
         self.num_classes = 5  # only available after instantiation
 
 
-def test_lightning_cli_link_arguments(tmpdir):
+def test_lightning_cli_link_arguments():
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.link_arguments("data.batch_size", "model.batch_size")
             parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")
 
-    cli_args = [f"--trainer.default_root_dir={tmpdir}", "--data.batch_size=12"]
+    cli_args = ["--data.batch_size=12"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)
@@ -495,9 +516,9 @@ def on_fit_start(self):
 
 # mps not yet supported by distributed
 @RunIf(skip_windows=True, mps=False)
-@pytest.mark.parametrize("logger", (False, True))
+@pytest.mark.parametrize("logger", (None, TensorBoardLogger(".")))
 @pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp"))
-def test_cli_distributed_save_config_callback(tmpdir, logger, strategy):
+def test_cli_distributed_save_config_callback(cleandir, logger, strategy):
     from torch.multiprocessing import ProcessRaisedException
 
     with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
@@ -506,7 +527,6 @@
         LightningCLI(
             EarlyExitTestModel,
             trainer_defaults={
-                "default_root_dir": str(tmpdir),
                 "logger": logger,
                 "max_steps": 1,
                 "max_epochs": 1,
@@ -516,17 +536,17 @@
             },
         )
     if logger:
-        config_dir = tmpdir / "lightning_logs"
+        config_dir = Path("lightning_logs")
         # no more version dirs should get created
         assert os.listdir(config_dir) == ["version_0"]
         config_path = config_dir / "version_0" / "config.yaml"
     else:
-        config_path = tmpdir / "config.yaml"
+        config_path = "config.yaml"
     assert os.path.isfile(config_path)
 
 
-def test_cli_config_overwrite(tmpdir):
-    trainer_defaults = {"default_root_dir": str(tmpdir), "logger": False, "max_steps": 1, "max_epochs": 1}
+def test_cli_config_overwrite(cleandir):
+    trainer_defaults = {"max_steps": 1, "max_epochs": 1, "logger": False}
 
     argv = ["any.py", "fit"]
     with mock.patch("sys.argv", argv):
@@ -538,13 +558,13 @@
 
 
 @pytest.mark.parametrize("run", (False, True))
-def test_lightning_cli_optimizer(tmpdir, run):
+def test_lightning_cli_optimizer(run):
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_optimizer_args(torch.optim.Adam)
 
     match = "BoringModel.configure_optimizers` will be overridden by " "`MyLightningCLI.configure_optimizers`"
-    argv = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1"] if run else []
+    argv = ["fit", "--trainer.fast_dev_run=1"] if run else []
     with mock.patch("sys.argv", ["any.py"] + argv), pytest.warns(UserWarning, match=match):
         cli = MyLightningCLI(BoringModel, run=run)
@@ -559,13 +579,13 @@
     assert len(cli.trainer.lr_scheduler_configs) == 0
 
 
-def test_lightning_cli_optimizer_and_lr_scheduler(tmpdir):
+def test_lightning_cli_optimizer_and_lr_scheduler():
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_optimizer_args(torch.optim.Adam)
             parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
 
-    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.fast_dev_run=1", "--lr_scheduler.gamma=0.8"]
+    cli_args = ["fit", "--trainer.fast_dev_run=1", "--lr_scheduler.gamma=0.8"]
 
     with mock.patch("sys.argv", ["any.py"] + cli_args):
         cli = MyLightningCLI(BoringModel)
@@ -578,7 +598,7 @@
     assert cli.trainer.lr_scheduler_configs[0].scheduler.gamma == 0.8
 
 
-def test_cli_no_need_configure_optimizers():
+def test_cli_no_need_configure_optimizers(cleandir):
     class BoringModel(LightningModule):
         def __init__(self):
             super().__init__()
@@ -605,7 +625,7 @@ def train_dataloader(self):
     verify.assert_called_once_with(cli.trainer, cli.model)
 
 
-def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(tmpdir):
+def test_lightning_cli_optimizer_and_lr_scheduler_subclasses(cleandir):
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
@@ -615,7 +635,6 @@ def add_arguments_to_parser(self, parser):
     lr_scheduler_arg = dict(class_path="torch.optim.lr_scheduler.StepLR", init_args=dict(step_size=50))
     cli_args = [
         "fit",
-        f"--trainer.default_root_dir={tmpdir}",
         "--trainer.max_epochs=1",
         f"--optimizer={json.dumps(optimizer_arg)}",
         f"--lr_scheduler={json.dumps(lr_scheduler_arg)}",
     ]
@@ -632,7 +651,7 @@ def add_arguments_to_parser(self, parser):
 
 
 @pytest.mark.parametrize("use_generic_base_class", [False, True])
-def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base_class, tmpdir):
+def test_lightning_cli_optimizers_and_lr_scheduler_with_link_to(use_generic_base_class):
     class MyLightningCLI(LightningCLI):
         def add_arguments_to_parser(self, parser):
             parser.add_optimizer_args(
@@ -653,7 +672,7 @@ def __init__(self, optim1: dict, optim2: dict, scheduler: dict):
             self.optim1 = instantiate_class(self.parameters(), optim1)
             self.optim2 = instantiate_class(self.parameters(), optim2)
             self.scheduler = instantiate_class(self.optim1, scheduler)
-    cli_args = ["fit", f"--trainer.default_root_dir={tmpdir}", "--trainer.max_epochs=1"]
+    cli_args = ["fit", "--trainer.fast_dev_run=1"]
     if use_generic_base_class:
         cli_args += [
             "--optim1",
@@ -783,7 +802,7 @@ def subcommands():
     assert "The y (type: float, default: 1.0)" in out
 
 
-def test_lightning_cli_run():
+def test_lightning_cli_run(cleandir):
     with mock.patch("sys.argv", ["any.py"]):
         cli = LightningCLI(BoringModel, run=False)
     assert cli.trainer.global_step == 0
@@ -917,7 +936,7 @@ def test_callbacks_append(use_class_path_callbacks):
     assert all(t in callback_types for t in expected)
 
 
-def test_optimizers_and_lr_schedulers_reload(tmpdir):
+def test_optimizers_and_lr_schedulers_reload(cleandir):
     base = ["any.py", "--trainer.max_epochs=1"]
     input = base + [
         "--lr_scheduler",
@@ -942,13 +961,13 @@
     assert dict_config["lr_scheduler"]["class_path"] == "torch.optim.lr_scheduler.OneCycleLR"
 
     # reload config
-    yaml_config_file = tmpdir / "config.yaml"
-    yaml_config_file.write_text(yaml_config, "utf-8")
+    yaml_config_file = Path("config.yaml")
+    yaml_config_file.write_text(yaml_config)
     with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
         LightningCLI(BoringModel, run=False)
 
 
-def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(tmpdir):
+def test_optimizers_and_lr_schedulers_add_arguments_to_parser_implemented_reload(cleandir):
     class TestLightningCLI(LightningCLI):
         def __init__(self, *args):
             super().__init__(*args, run=False)
@@ -1007,8 +1026,8 @@ def __init__(self, opt1_config: dict, opt2_config: dict, sch_config: dict):
     assert dict_config["something"] == ["a", "b", "c"]
 
     # reload config
-    yaml_config_file = tmpdir / "config.yaml"
-    yaml_config_file.write_text(yaml_config, "utf-8")
+    yaml_config_file = Path("config.yaml")
+    yaml_config_file.write_text(yaml_config)
     with mock.patch("sys.argv", base + [f"--config={yaml_config_file}"]):
         cli = TestLightningCLI(TestModel)
@@ -1106,14 +1125,14 @@ def test_lightning_cli_config_before_and_after_subcommand():
     assert cli.trainer.fast_dev_run == 1
 
 
-def test_lightning_cli_parse_kwargs_with_subcommands(tmpdir):
+def test_lightning_cli_parse_kwargs_with_subcommands(cleandir):
     fit_config = {"trainer": {"limit_train_batches": 2}}
-    fit_config_path = tmpdir / "fit.yaml"
-    fit_config_path.write_text(str(fit_config), "utf8")
+    fit_config_path = Path("fit.yaml")
+    fit_config_path.write_text(str(fit_config))
 
     validate_config = {"trainer": {"limit_val_batches": 3}}
-    validate_config_path = tmpdir / "validate.yaml"
-    validate_config_path.write_text(str(validate_config), "utf8")
+    validate_config_path = Path("validate.yaml")
+    validate_config_path.write_text(str(validate_config))
 
     parser_kwargs = {
         "fit": {"default_config_files": [str(fit_config_path)]},
@@ -1137,15 +1156,15 @@
     assert cli.trainer.limit_val_batches == 3
 
 
-def test_lightning_cli_subcommands_common_default_config_files(tmpdir):
+def test_lightning_cli_subcommands_common_default_config_files(cleandir):
     class Model(BoringModel):
         def __init__(self, foo: int, *args, **kwargs):
             super().__init__(*args, **kwargs)
             self.foo = foo
 
     config = {"fit": {"model": {"foo": 123}}}
-    config_path = tmpdir / "default.yaml"
-    config_path.write_text(str(config), "utf8")
+    config_path = Path("default.yaml")
+    config_path.write_text(str(config))
 
     parser_kwargs = {"default_config_files": [str(config_path)]}
     with mock.patch("sys.argv", ["any.py", "fit"]), mock.patch(
diff --git a/tests/tests_pytorch/tuner/test_scale_batch_size.py b/tests/tests_pytorch/tuner/test_scale_batch_size.py
index 643899f05b240..8d51b8ac875d9 100644
--- a/tests/tests_pytorch/tuner/test_scale_batch_size.py
+++ b/tests/tests_pytorch/tuner/test_scale_batch_size.py
@@ -29,8 +29,8 @@
 
 
 class BatchSizeDataModule(BoringDataModule):
-    def __init__(self, data_dir, batch_size):
-        super().__init__(data_dir)
+    def __init__(self, batch_size):
+        super().__init__()
         if batch_size is not None:
             self.batch_size = batch_size
 
@@ -63,7 +63,7 @@ def test_scale_batch_size_method_with_model_or_datamodule(tmpdir, model_bs, dm_b
     trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=0, max_epochs=1)
 
     model = BatchSizeModel(model_bs)
-    datamodule = BatchSizeDataModule(tmpdir, dm_bs) if dm_bs != -1 else None
+    datamodule = BatchSizeDataModule(dm_bs) if dm_bs != -1 else None
 
     new_batch_size = trainer.tuner.scale_batch_size(
         model, mode="binsearch", init_val=4, max_trials=2, datamodule=datamodule
@@ -170,8 +170,8 @@ def test_auto_scale_batch_size_set_datamodule_attribute(tmpdir, use_hparams):
     before_batch_size = hparams["batch_size"]
 
     class HparamsBatchSizeDataModule(BoringDataModule):
-        def __init__(self, data_dir, batch_size):
-            super().__init__(data_dir)
+        def __init__(self, batch_size):
+            super().__init__()
             self.save_hyperparameters()
 
         def train_dataloader(self):
@@ -181,7 +181,7 @@ def val_dataloader(self):
         return DataLoader(RandomDataset(32, 64), batch_size=self.hparams.batch_size)
 
     datamodule_class = HparamsBatchSizeDataModule if use_hparams else BatchSizeDataModule
-    datamodule = datamodule_class(data_dir=tmpdir, batch_size=before_batch_size)
+    datamodule = datamodule_class(batch_size=before_batch_size)
     model = BatchSizeModel(**hparams)
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
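For context, a minimal usage sketch of the CSVLogger change above, which widens save_dir from str to _PATH and normalizes it with os.fspath. This snippet is illustrative only and not part of the patch; it assumes the logger's existing save_dir property and log_metrics/save methods, and a writable /tmp/demo_logs directory.

# Sketch: passing a pathlib.Path as save_dir (assumes the patched CSVLogger is installed).
from pathlib import Path

from pytorch_lightning.loggers.csv_logs import CSVLogger

logger = CSVLogger(Path("/tmp/demo_logs"), name="demo")
assert isinstance(logger.save_dir, str)  # __init__ normalizes the path via os.fspath()

logger.log_metrics({"loss": 0.42}, step=0)
logger.save()  # flushes metrics.csv under /tmp/demo_logs/demo/version_<n>/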