Skip to content

Commit

Permalink
Fix TensorBoardLogger test on Windows (#19824)
Browse files Browse the repository at this point in the history
  • Loading branch information
awaelchli authored Apr 29, 2024
1 parent 49ed2b1 commit d194976
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 11 deletions.
2 changes: 1 addition & 1 deletion requirements/fabric/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

numpy >=1.17.2, <1.27.0
torch >=2.0.0, <2.4.0
fsspec[http] >=2022.5.0, <2023.11.0
fsspec[http] >=2022.5.0, <2024.4.0
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
lightning-utilities >=0.8.0, <0.12.0
2 changes: 1 addition & 1 deletion requirements/pytorch/base.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ numpy >=1.17.2, <1.27.0
torch >=2.0.0, <2.4.0
tqdm >=4.57.0, <4.67.0
PyYAML >=5.4, <6.1.0
fsspec[http] >=2022.5.0, <2023.11.0
fsspec[http] >=2022.5.0, <2024.4.0
torchmetrics >=0.7.0, <1.3.0 # needed for using fixed compare_version
packaging >=20.0, <=23.1
typing-extensions >=4.4.0, <4.10.0
Expand Down
18 changes: 9 additions & 9 deletions tests/tests_pytorch/loggers/test_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def test_tensorboard_no_name(tmp_path, name):
assert os.listdir(tmp_path / "version_0")


@mock.patch.dict(os.environ, {}, clear=True)
def test_tensorboard_log_sub_dir(tmp_path):
class TestLogger(TensorBoardLogger):
# for reproducibility
Expand Down Expand Up @@ -141,14 +140,15 @@ def name(self):
trainer = Trainer(**trainer_args, logger=logger)
assert trainer.logger.log_dir == os.path.join(explicit_save_dir, "name", "version", "sub_dir")

# test env var (`$`) handling
test_env_dir = "some_directory"
os.environ["TEST_ENV_DIR"] = test_env_dir
save_dir = "$TEST_ENV_DIR/tmp"
explicit_save_dir = f"{test_env_dir}/tmp"
logger = TestLogger(save_dir, sub_dir="sub_dir")
trainer = Trainer(**trainer_args, logger=logger)
assert trainer.logger.log_dir == os.path.join(explicit_save_dir, "name", "version", "sub_dir")
with mock.patch.dict(os.environ, {}):
# test env var (`$`) handling
test_env_dir = "some_directory"
os.environ["TEST_ENV_DIR"] = test_env_dir
save_dir = "$TEST_ENV_DIR/tmp"
explicit_save_dir = f"{test_env_dir}/tmp"
logger = TestLogger(save_dir, sub_dir="sub_dir")
trainer = Trainer(**trainer_args, logger=logger)
assert trainer.logger.log_dir == os.path.join(explicit_save_dir, "name", "version", "sub_dir")


@pytest.mark.parametrize("step_idx", [10, None])
Expand Down

0 comments on commit d194976

Please sign in to comment.