From dddad21b38bbdcb9fc16777febaec5ca9731271f Mon Sep 17 00:00:00 2001
From: mlahariya <40852060+mlahariya@users.noreply.github.com>
Date: Thu, 5 Dec 2024 12:52:59 +0100
Subject: [PATCH] Update dict data loader train tests

---
 tests/ml_tools/test_train.py | 60 ++++++++++++++++++++++++++++++++++++
 1 file changed, 60 insertions(+)

diff --git a/tests/ml_tools/test_train.py b/tests/ml_tools/test_train.py
index 52daa5b7..0af82d67 100644
--- a/tests/ml_tools/test_train.py
+++ b/tests/ml_tools/test_train.py
@@ -21,6 +21,15 @@ def dataloader(batch_size: int = 25) -> DataLoader:
     y = torch.sin(x)
     return to_dataloader(x, y, batch_size=batch_size, infinite=True)
 
+def dictdataloader(data_configs: dict[str, dict[str, int]]) -> DictDataLoader:
+    dls = {}
+    for name, config in data_configs.items():
+        data_size = config["data_size"]
+        batch_size = config["batch_size"]
+        x = torch.rand(data_size, 1)
+        y = torch.sin(x)
+        dls[name] = to_dataloader(x, y, batch_size=batch_size, infinite=True)
+    return DictDataLoader(dls)
 
 def train_val_dataloaders(batch_size: int = 25) -> tuple:
     x = torch.rand(batch_size, 1)
@@ -310,3 +319,54 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, d
     # but that is time-consuming since training must be run twice for comparison.
     # The below check may be plausible enough.
     assert len(files) == 1  # Since only the best checkpoint must be stored.
+
+def test_dict_dataloader_with_trainer(tmp_path: Path, Basic: torch.nn.Module) -> None:
+    data_configs = {
+        "dataset1": {"data_size": 30, "batch_size": 5},
+        "dataset2": {"data_size": 50, "batch_size": 10},
+    }
+    dict_loader = dictdataloader(data_configs)
+
+    # Define the model, loss function and optimizer.
+    model = Basic
+    criterion = torch.nn.MSELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+
+    # Average the per-dataset MSE losses over the batch dict.
+    def loss_fn(model: torch.nn.Module, data: dict) -> tuple[torch.Tensor, dict]:
+        losses = []
+        for x, y in data.values():
+            out = model(x)
+            losses.append(criterion(out, y))
+        total_loss = sum(losses) / len(losses)
+        return total_loss, {}
+
+    config = TrainConfig(
+        root_folder=tmp_path,
+        max_iter=50,
+        checkpoint_every=10,
+        write_every=10,
+    )
+
+    trainer = Trainer(model, optimizer, config, loss_fn, dict_loader)
+    with trainer.enable_grad_opt():
+        trainer.fit()
+
+    # The trained model should produce outputs of the expected shape.
+    x = torch.rand(5, 1)
+    y_pred = model(x)
+    assert y_pred.shape == (5, 1)
+
+def test_dict_dataloader() -> None:
+    data_configs = {
+        "dataset1": {"data_size": 20, "batch_size": 5},
+        "dataset2": {"data_size": 40, "batch_size": 10},
+    }
+    ddl = dictdataloader(data_configs)
+    assert set(ddl.dataloaders.keys()) == {"dataset1", "dataset2"}
+
+    batch = next(iter(ddl))
+    assert batch["dataset1"][0].shape == (5, 1)
+    assert batch["dataset2"][0].shape == (10, 1)
+    for x, y in batch.values():
+        assert x.shape == y.shape
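
Note appended for review (not part of the patch): the sketch below is a
standalone illustration of the DictDataLoader contract these tests assert,
namely that each iteration step yields one dict, keyed by loader name, holding
one (x, y) batch per dataset. The import path is an assumption based on the
qadence ml_tools package this test module exercises; adjust it to match the
imports at the top of test_train.py.

    import torch

    from qadence.ml_tools import DictDataLoader, to_dataloader  # assumed import path

    # Build two named infinite loaders with different sizes and batch sizes,
    # mirroring the dictdataloader() helper added in the patch.
    dls = {}
    for name, (size, bs) in {"dataset1": (20, 5), "dataset2": (40, 10)}.items():
        x = torch.rand(size, 1)
        dls[name] = to_dataloader(x, torch.sin(x), batch_size=bs, infinite=True)
    ddl = DictDataLoader(dls)

    # One iteration step yields a dict keyed like `dls`; each value is that
    # loader's (x, y) batch. This is the shape contract test_dict_dataloader checks.
    batch = next(iter(ddl))
    for name, (x, y) in batch.items():
        print(name, tuple(x.shape), tuple(y.shape))  # (5, 1) then (10, 1)

The loaders are built with infinite=True, presumably so that Trainer.fit() can
draw a fresh batch at every one of its max_iter steps without the smaller
dataset exhausting first.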