Update dict data loader train tests
mlahariya committed Dec 5, 2024
1 parent 466d77d commit dddad21
Showing 1 changed file with 60 additions and 0 deletions.
tests/ml_tools/test_train.py: 60 additions & 0 deletions
@@ -21,6 +21,15 @@ def dataloader(batch_size: int = 25) -> DataLoader:
    y = torch.sin(x)
    return to_dataloader(x, y, batch_size=batch_size, infinite=True)

def dictdataloader(data_configs: dict[str, dict[str, int]]) -> DictDataLoader:
    dls = {}
    for name, config in data_configs.items():
        data_size = config["data_size"]
        batch_size = config["batch_size"]
        x = torch.rand(data_size, 1)
        y = torch.sin(x)
        dls[name] = to_dataloader(x, y, batch_size=batch_size, infinite=True)
    return DictDataLoader(dls)

def train_val_dataloaders(batch_size: int = 25) -> tuple:
    x = torch.rand(batch_size, 1)
@@ -310,3 +319,54 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]
    # but that is time-consuming since training must be run twice for comparison.
    # The check below should be sufficient.
    assert len(files) == 1  # Only the best checkpoint should be stored.

def test_dict_dataloader_with_trainer(tmp_path: Path, Basic: torch.nn.Module) -> None:
    data_configs = {
        "dataset1": {"data_size": 30, "batch_size": 5},
        "dataset2": {"data_size": 50, "batch_size": 10},
    }
    dict_loader = dictdataloader(data_configs)

    # Define the model, loss function, and optimizer.
    model = Basic
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    def loss_fn(model: torch.nn.Module, data: dict) -> tuple[torch.Tensor, dict]:
        # Average the MSE loss over all named datasets in the batch dict.
        losses = []
        for key, (x, y) in data.items():
            out = model(x)
            loss = criterion(out, y)
            losses.append(loss)
        total_loss = sum(losses) / len(losses)
        return total_loss, {}

    config = TrainConfig(
        root_folder=tmp_path,
        max_iter=50,
        checkpoint_every=10,
        write_every=10,
    )

    trainer = Trainer(model, optimizer, config, loss_fn, dict_loader)
    with trainer.enable_grad_opt():
        trainer.fit()

    x = torch.rand(5, 1)
    for key in dict_loader.dataloaders.keys():
        y_pred = model(x)
        assert y_pred.shape == (5, 1)

def test_dict_dataloader() -> None:
    data_configs = {
        "dataset1": {"data_size": 20, "batch_size": 5},
        "dataset2": {"data_size": 40, "batch_size": 10},
    }
    ddl = dictdataloader(data_configs)
    assert set(ddl.dataloaders.keys()) == {"dataset1", "dataset2"}

    batch = next(iter(ddl))
    assert batch["dataset1"][0].shape == (5, 1)
    assert batch["dataset2"][0].shape == (10, 1)
    for key, (x, y) in batch.items():
        assert x.shape == y.shape
