-
Notifications
You must be signed in to change notification settings - Fork 21
/
utils.py
36 lines (33 loc) · 1.29 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import os
import re
def get_last_checkpoint_or_last_model(folder):
    """Return the directory holding the model to load, or None if there is none.

    Modification of ``get_last_checkpoint`` from ``transformers.trainer_utils``.
    If *folder* itself directly contains an entry whose name starts with
    ``"pytorch_model"`` (e.g. ``pytorch_model.bin``), *folder* is returned.
    The default HF function ignores those and only looks for ``checkpoint-*``
    sub-folders; that lookup is used here as the fallback.

    Args:
        folder: Path of the directory to inspect.

    Returns:
        *folder* if it contains ``pytorch_model*`` entries; otherwise the path
        of the ``checkpoint-<step>`` sub-directory with the largest integer
        step; otherwise ``None``.

    Raises:
        OSError: propagated from ``os.listdir`` if *folder* cannot be listed.
    """
    prefix_checkpoint_dir = "checkpoint"
    re_checkpoint = re.compile(r"^" + prefix_checkpoint_dir + r"-(\d+)$")
    content = os.listdir(folder)

    # Prefix test equivalent to the glob "pytorch_model*". The previous regex
    # "pytorch_model*" meant "pytorch_mode" followed by zero or more "l",
    # searched anywhere in the name, so unrelated files such as
    # "not_pytorch_mode.txt" were mistaken for saved model weights.
    if any(name.startswith("pytorch_model") for name in content):
        return folder

    # Map each "checkpoint-<step>" sub-directory to its numeric step so the
    # regex runs once per entry; plain files with that name are ignored.
    checkpoint_steps = {}
    for name in content:
        match = re_checkpoint.match(name)
        if match is not None and os.path.isdir(os.path.join(folder, name)):
            checkpoint_steps[name] = int(match.group(1))

    if not checkpoint_steps:
        return None
    # Compare numerically, so "checkpoint-20" beats "checkpoint-5".
    return os.path.join(folder, max(checkpoint_steps, key=checkpoint_steps.get))
def parse_checkpoint_step(checkpoint):
    """Extract the training step from a checkpoint name like ``"checkpoint-500"``.

    The name is split on ``"-"``. The first component must be exactly
    ``"checkpoint"``, and the last component is parsed as the step.

    Args:
        checkpoint: The checkpoint name (a bare name, not a path).

    Returns:
        The step as an int, or -1 if the name does not start with
        ``"checkpoint"`` or its last component is not an integer. In the
        latter case a diagnostic message is printed.
    """
    parts = checkpoint.split("-")
    if parts[0] != "checkpoint":
        return -1
    try:
        return int(parts[-1])
    except ValueError:
        # Narrowed from a bare ``except:``: int() on a str only raises
        # ValueError, and a bare except would also swallow
        # KeyboardInterrupt/SystemExit.
        print(f"got checkpoint name {checkpoint}, couldn't parse step")
        return -1