
Commit

Cleanup
mlahariya committed Dec 5, 2024
1 parent dddad21 commit 0411419
Showing 2 changed files with 14 additions and 8 deletions.
10 changes: 6 additions & 4 deletions qadence/ml_tools/trainer.py
@@ -604,10 +604,12 @@ def _modify_batch_end_loss_metrics(
             updated_metrics[f"{phase}_loss"] = loss
             return loss, updated_metrics
         return loss_metrics

-    def _reset_model_and_opt(self):
+    def _reset_model_and_opt(self) -> None:
         """
-        Save model_old and optimizer_old for epoch 0. This allows us to create a copy of model
+        Save model_old and optimizer_old for epoch 0.
+
+        This allows us to create a copy of model
         and optimizer before running the optimization.
         We do this because optimize step provides loss, metrics
@@ -621,7 +623,7 @@ def _reset_model_and_opt(self):
             # Deep copy model and optimizer to maintain checkpoints
             self.model_old = copy.deepcopy(self.model)
             self.optimizer_old = copy.deepcopy(self.optimizer)
-        except:
+        except Exception:
            self.model_old = self.model
            self.optimizer_old = self.optimizer

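For context on the two trainer.py changes: copy.deepcopy can fail for models or optimizers that hold non-copyable state, so the code falls back to aliasing the live objects, and the new "except Exception:" (unlike a bare "except:") still lets KeyboardInterrupt and SystemExit propagate. Below is a minimal, self-contained sketch of that pattern written against plain PyTorch rather than qadence's actual Trainer class; the SimpleTrainer name and constructor are illustrative assumptions, not the library's API.

import copy

import torch


class SimpleTrainer:
    """Illustrative trainer that snapshots the model/optimizer at the start of training."""

    def __init__(self, model: torch.nn.Module, optimizer: torch.optim.Optimizer) -> None:
        self.model = model
        self.optimizer = optimizer
        self.model_old = model
        self.optimizer_old = optimizer

    def _reset_model_and_opt(self) -> None:
        try:
            # Deep copies keep an independent snapshot for checkpointing.
            self.model_old = copy.deepcopy(self.model)
            self.optimizer_old = copy.deepcopy(self.optimizer)
        except Exception:
            # Some models/optimizers cannot be deep-copied; fall back to aliasing
            # the live objects. "except Exception" (unlike a bare "except:") still
            # lets KeyboardInterrupt and SystemExit propagate.
            self.model_old = self.model
            self.optimizer_old = self.optimizer


model = torch.nn.Linear(1, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
trainer = SimpleTrainer(model, optimizer)
trainer._reset_model_and_opt()
assert trainer.model_old is not model  # deepcopy succeeded, so the snapshot is independent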
12 changes: 8 additions & 4 deletions tests/ml_tools/test_train.py
@@ -21,6 +21,7 @@ def dataloader(batch_size: int = 25) -> DataLoader:
     y = torch.sin(x)
     return to_dataloader(x, y, batch_size=batch_size, infinite=True)

+
 def dictdataloader(data_configs: dict[str, dict[str, int]]) -> DictDataLoader:
     dls = {}
     for name, config in data_configs.items():
@@ -31,6 +32,7 @@ def dictdataloader(data_configs: dict[str, dict[str, int]]) -> DictDataLoader:
         dls[name] = to_dataloader(x, y, batch_size=batch_size, infinite=True)
     return DictDataLoader(dls)

+
 def train_val_dataloaders(batch_size: int = 25) -> tuple:
     x = torch.rand(batch_size, 1)
     y = torch.sin(x)
@@ -320,6 +322,7 @@ def loss_fn(model: torch.nn.Module, data: torch.Tensor) -> tuple[torch.Tensor, dict]:
     # The below check may be plausible enough.
     assert len(files) == 1  # Since only the best checkpoint must be stored.

+
 def test_dict_dataloader_with_trainer(tmp_path: Path, Basic: torch.nn.Module) -> None:
     data_configs = {
         "dataset1": {"data_size": 30, "batch_size": 5},
@@ -332,7 +335,7 @@ def test_dict_dataloader_with_trainer(tmp_path: Path, Basic: torch.nn.Module) -> None:
     criterion = torch.nn.MSELoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

-    def loss_fn(model : torch.nn.Module, data : dict) -> tuple[torch.Tensor, dict]:
+    def loss_fn(model: torch.nn.Module, data: dict) -> tuple[torch.Tensor, dict]:
         losses = []
         for key, (x, y) in data.items():
             out = model(x)
@@ -357,6 +360,7 @@ def loss_fn(model : torch.nn.Module, data : dict) -> tuple[torch.Tensor, dict]:
     y_pred = model(x)
     assert y_pred.shape == (5, 1)

+
 def test_dict_dataloader() -> None:
     data_configs = {
         "dataset1": {"data_size": 20, "batch_size": 5},
@@ -366,7 +370,7 @@ def test_dict_dataloader() -> None:
     assert set(ddl.dataloaders.keys()) == {"dataset1", "dataset2"}

     batch = next(iter(ddl))
-    assert batch["dataset1"][0].shape == (5, 1)
-    assert batch["dataset2"][0].shape == (10, 1)
+    assert batch["dataset1"][0].shape == (5, 1)
+    assert batch["dataset2"][0].shape == (10, 1)
     for key, (x, y) in batch.items():
-        assert x.shape == y.shape
+        assert x.shape == y.shape
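The dict-dataloader tests exercise a loss function that receives one batch per dataset and aggregates them into a single scalar. Below is a self-contained sketch of that aggregation pattern, using plain tensors in place of a qadence DictDataLoader batch; the metric names and tensor shapes are illustrative assumptions.

import torch


def loss_fn(model: torch.nn.Module, data: dict) -> tuple[torch.Tensor, dict]:
    """Sum an MSE loss over every dataset in a dict batch and report per-dataset metrics."""
    criterion = torch.nn.MSELoss()
    losses = []
    metrics = {}
    for key, (x, y) in data.items():
        out = model(x)
        loss = criterion(out, y)
        losses.append(loss)
        metrics[f"{key}_loss"] = loss.detach()
    return torch.stack(losses).sum(), metrics


# Emulate one batch from a DictDataLoader: a mapping of dataset name -> (x, y) pair.
model = torch.nn.Linear(1, 1)
batch = {
    "dataset1": (torch.rand(5, 1), torch.rand(5, 1)),
    "dataset2": (torch.rand(10, 1), torch.rand(10, 1)),
}
total_loss, metrics = loss_fn(model, batch)
assert total_loss.ndim == 0 and set(metrics) == {"dataset1_loss", "dataset2_loss"}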
