checkpointManager.py
import os
import torch


class CheckpointManager:
    """
    Checkpoint format:
    {
        'global_step': int,
        'global_epoch': int,
        'block_size': int,
        'vocab_size': int,
        'n_layer': int,
        'n_head': int,
        'n_embd': int,
        'bias': bool,
        'dataset': str,
        'state_dict': dict,
        'optimizer': dict,
    }
    """

    def __init__(self, save_root):
        self.save_root = save_root
        if not os.path.exists(save_root):
            os.makedirs(save_root)

    def load(self, path, model, optimizer=None, compile=True):
        """Restore model (and optionally optimizer) state from a checkpoint file."""
        print(f'Loading state dict from {path}')
        checkpoint = torch.load(path)
        state_dict = checkpoint['state_dict']
        # Checkpoints saved from a torch.compile()'d model carry an '_orig_mod.'
        # prefix on every key; strip it when loading into an uncompiled model.
        if not compile:
            state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)
        if optimizer is not None:
            optimizer.load_state_dict(checkpoint['optimizer'])
        global_step = checkpoint['global_step']
        global_epoch = checkpoint['global_epoch']
        # Print the stored configuration, skipping the bulky weight/optimizer entries.
        print('checkpoint config:')
        for k, v in checkpoint.items():
            if k in ('state_dict', 'optimizer'):
                continue
            print(f'{k}={v}')
        print('-' * 6)
        return model, global_step, global_epoch

    def save(self, model, config, global_step, global_epoch, optimizer):
        """Write model, optimizer, and config metadata to <save_root>/cpt<step>.pth."""
        filename = f'cpt{global_step}.pth'
        path = os.path.join(self.save_root, filename)
        checkpoint = {
            'global_step': global_step,
            'global_epoch': global_epoch,
            'block_size': config.block_size,
            'vocab_size': config.vocab_size,
            'n_layer': config.n_layer,
            'n_head': config.n_head,
            'n_embd': config.n_embd,
            'bias': config.bias,
            'dataset': config.data_path,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(checkpoint, path)
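

if __name__ == '__main__':
    # Usage sketch (illustrative only, not part of the original module). The tiny
    # model and config below are hypothetical stand-ins; any nn.Module, torch
    # optimizer, and config object exposing the attributes read in save() would
    # be handled the same way.
    from types import SimpleNamespace
    import torch.nn as nn

    model = nn.Linear(8, 8)  # stand-in network
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    config = SimpleNamespace(block_size=128, vocab_size=50304, n_layer=2,
                             n_head=2, n_embd=8, bias=False, data_path='data.bin')

    manager = CheckpointManager('checkpoints')
    manager.save(model, config, global_step=100, global_epoch=1, optimizer=optimizer)
    model, step, epoch = manager.load(os.path.join('checkpoints', 'cpt100.pth'),
                                      model, optimizer=optimizer, compile=False)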