Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save and load task state to the checkpoint #5328

Merged
merged 1 commit into from
Sep 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion fairseq/checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,20 @@ def is_better(a, b):
"checkpoint_last{}.pt".format(suffix)
] = not cfg.no_last_checkpoints

extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
extra_state = {
"train_iterator": epoch_itr.state_dict(),
"val_loss": val_loss,
}

# Going forward, different tasks could expose an API like this to dump all
# the checkpoint worthy attributes in a dictionary which then will be
# merged with the parent dictionary to create the "extra_state". This
# allows for an extensible yet simple design to checkpoint task level
# attributes
if hasattr(trainer.task, "get_checkpoint_dict"):
extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
logger.info(f"{trainer.task.__class__} checkpoint worthy attributes are ready to be persisted with the checkpoint")

if hasattr(save_checkpoint, "best"):
extra_state.update({"best": save_checkpoint.best})

Expand Down Expand Up @@ -275,6 +288,11 @@ def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
)
epoch_itr.load_state_dict(itr_state)

# Preload the observer stats for Supernet
supernet_cp_dict = extra_state.get("supernet", {})
if supernet_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
trainer.task.set_checkpoint_dict(supernet_cp_dict)
else:
epoch_itr = trainer.get_train_iterator(
epoch=1, load_dataset=True, **passthrough_args
Expand Down
2 changes: 0 additions & 2 deletions tests/test_checkpoint_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
from io import StringIO
from unittest.mock import patch

from omegaconf import OmegaConf

from fairseq import checkpoint_utils
from tests.utils import (
create_dummy_data,
Expand Down
172 changes: 172 additions & 0 deletions tests/test_checkpoint_utils_for_task_level_attributes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
#!/usr/bin/env fbpython
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

import contextlib
import logging
import unittest
from io import StringIO
from unittest.mock import MagicMock, patch

import torch
from fairseq import checkpoint_utils, data
from omegaconf import OmegaConf


def mock_trainer(epoch, num_updates, iterations_in_epoch):
    """Build a MagicMock trainer whose checkpoint round-trip mimics a real one.

    The mocked ``load_checkpoint`` returns the dataloader iterator state plus
    the "supernet" task payload, and the mocked task exposes the
    get/set checkpoint-dict hooks exercised by these tests.
    """
    itr_state = {
        "epoch": epoch,
        "iterations_in_epoch": iterations_in_epoch,
        "shuffle": False,
    }
    mocked = MagicMock()
    mocked.load_checkpoint.return_value = {
        "train_iterator": itr_state,
        "supernet": checkpoint_dict()["supernet"],
    }
    mocked.get_num_updates.return_value = num_updates
    mocked.task.get_checkpoint_dict.return_value = checkpoint_dict()
    mocked.task.set_checkpoint_dict = MagicMock()
    return mocked


def checkpoint_dict():
    """Return the canonical task-level checkpoint payload used by the tests.

    The payload mirrors what a quantization-aware "supernet" task would
    persist: observer statistics keyed by a (bits, bits, observer, observer)
    tuple, mapping module names to their recorded stats.
    """
    observer_key = (
        4,
        16,
        "MovingAveragePerChannelMinMax",
        "MovingAveragePerChannelMinMax",
    )
    observer_stats = {observer_key: {"mod1": 1, "mod2": 2, "mod3": 3}}
    return {"supernet": {"observer_stats": observer_stats}}


def mock_dict():
    """Return a mocked fairseq dictionary with fixed special-symbol indices."""
    dictionary = MagicMock()
    # pad/eos/unk are the only symbols the datasets under test look up.
    dictionary.configure_mock(
        **{
            "pad.return_value": 1,
            "eos.return_value": 2,
            "unk.return_value": 3,
        }
    )
    return dictionary


def get_trainer_and_epoch_itr(epoch, epoch_size, num_updates, iterations_in_epoch):
    """Create a mocked trainer paired with a real fairseq EpochBatchIterator.

    The iterator wraps a tiny synthetic LanguagePairDataset of ``epoch_size``
    single-token examples so checkpoint save/load can exercise genuine
    dataloader state.
    """
    token_row = torch.LongTensor(list(range(epoch_size))).view(1, -1)
    block_ds = data.TokenBlockDataset(
        token_row,
        sizes=[token_row.size(-1)],
        block_size=1,
        pad=0,
        eos=1,
        include_targets=False,
    )
    pair_ds = data.LanguagePairDataset(
        block_ds, block_ds.sizes, mock_dict(), shuffle=False
    )
    # One example per batch keeps iteration counts predictable in the tests.
    batch_iter = data.EpochBatchIterator(
        dataset=pair_ds,
        collate_fn=pair_ds.collater,
        batch_sampler=[[i] for i in range(epoch_size)],
    )
    return mock_trainer(epoch, num_updates, iterations_in_epoch), batch_iter


def get_mock_cfg(finetune_from_model):
    """Build an OmegaConf config with the fields checkpoint save/load reads.

    Args:
        finetune_from_model: value for ``checkpoint.finetune_from_model``
            (``None`` means normal resume rather than fine-tuning).
    """
    checkpoint_section = {
        "save_dir": None,
        "optimizer_overrides": "{}",
        "reset_dataloader": False,
        "reset_meters": False,
        "reset_optimizer": False,
        "reset_lr_scheduler": False,
        "finetune_from_model": finetune_from_model,
        "model_parallel_size": 1,
        "restore_file": "checkpoint_last.pt",
        "no_save": False,
        "save_interval_updates": 0,
        "no_last_checkpoints": False,
        "keep_interval_updates": 0,
        "keep_last_epochs": 0,
        "keep_best_checkpoints": 0,
    }
    common_section = {"model_parallel_size": 1}
    return OmegaConf.create(
        {"checkpoint": checkpoint_section, "common": common_section}
    )


class TestCheckpointsForTaskLevelAttributes(unittest.TestCase):
    """Verify that task-level attributes survive a checkpoint save/load cycle."""

    def setUp(self) -> None:
        self.cfg_mock = get_mock_cfg(None)
        # Stub out filesystem access so no real checkpoint files are touched.
        patch_targets = {
            "os.makedirs": MagicMock(),
            "os.path.join": MagicMock(),
            "os.path.isfile": MagicMock(return_value=True),
            "os.path.isabs": MagicMock(return_value=False),
            "fairseq.file_io.PathManager.exists": MagicMock(return_value=False),
        }
        self.applied_patches = []
        for target, replacement in patch_targets.items():
            patcher = patch(target, replacement)
            patcher.start()
            self.applied_patches.append(patcher)
        logging.disable(logging.CRITICAL)

        self.trainer, self.epoch_itr = get_trainer_and_epoch_itr(2, 150, 200, 50)
        self.trainer.get_train_iterator = MagicMock(return_value=self.epoch_itr)
        self.epoch_itr.next_epoch_itr(shuffle=False)

        # Persist a checkpoint up front so every test starts from saved state.
        checkpoint_utils.save_checkpoint(
            self.cfg_mock.checkpoint, self.trainer, self.epoch_itr, None
        )

    def tearDown(self):
        patch.stopall()
        logging.disable(logging.NOTSET)

    def test_verify_checkpoint(self) -> None:
        # The task payload handed to save_checkpoint must round-trip intact.
        observer_key = (
            4,
            16,
            "MovingAveragePerChannelMinMax",
            "MovingAveragePerChannelMinMax",
        )
        cp_dict = self.trainer.task.get_checkpoint_dict()
        self.assertEqual(len(cp_dict), 1)
        self.assertIn("supernet", cp_dict)
        self.assertIn("observer_stats", cp_dict["supernet"])
        stats = cp_dict["supernet"]["observer_stats"]
        self.assertEqual(len(stats), 1)
        self.assertIn(observer_key, stats)
        self.assertEqual(stats[observer_key], {"mod1": 1, "mod2": 2, "mod3": 3})

    def test_load_checkpoint(self) -> None:
        # Loading the checkpoint must hand the "supernet" payload back to the
        # task via set_checkpoint_dict exactly once.
        with contextlib.redirect_stdout(StringIO()):
            _, epoch_itr = checkpoint_utils.load_checkpoint(
                self.cfg_mock.checkpoint, self.trainer
            )

        expected_payload = checkpoint_dict()["supernet"]
        self.trainer.task.set_checkpoint_dict.assert_called_once_with(
            expected_payload
        )


# Allow running this test module directly, outside a test runner.
if __name__ == "__main__":
    unittest.main()

Loading