forked from abigailhayes/rl-for-vrp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
am_epochs.py
102 lines (92 loc) · 2.64 KB
/
am_epochs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import methods.rl4co_run as rl4co_run
import utils
from rl4co.utils import RL4COTrainer
from lightning.pytorch.callbacks import ModelCheckpoint, RichModelSummary
import os
import json
# Experiment: evaluate an RL4CO model before training and after training
# budgets of 1, 2, 5 and 10 epochs, then dump the per-budget test results
# and routes to JSON files under results/am_epochs/.

# Directory for trainer logs (default_root_dir) and the JSON outputs.
OUTPUT_DIR = "results/am_epochs"
# max_epochs values to compare; key 0 in the outputs is the evaluation done
# before any training performed by this script.
EPOCH_BUDGETS = (1, 2, 5, 10)


def _build_model(exp_args, ident):
    """Create a fresh ``rl4co_run.RL4CO`` wrapper and call ``set_model()`` on it.

    Args:
        exp_args: Parsed settings from ``utils.parse_experiment()``; the keys
            "problem", "seed" and "method_settings" ("init_method",
            "customers") are read here.
        ident: Experiment identifier passed through to the wrapper.

    Returns:
        The wrapper after ``set_model()`` has been called.
    """
    model = rl4co_run.RL4CO(
        exp_args["problem"],
        exp_args["method_settings"]["init_method"],
        exp_args["method_settings"]["customers"],
        exp_args["seed"],
        ident,
        "greedy",
    )
    model.set_model()
    return model


def _evaluate(model, exp_args, ident):
    """Run the problem-specific test routine on ``model`` with ``save=False``.

    ``utils.test_cvrp`` is used when ``exp_args["problem"]`` is "CVRP"; every
    other problem falls through to ``utils.test_cvrptw``.

    Returns:
        The ``(results, routes)`` pair produced by the chosen test routine.
    """
    test_fn = utils.test_cvrp if exp_args["problem"] == "CVRP" else utils.test_cvrptw
    return test_fn(
        exp_args["method"],
        exp_args["method_settings"],
        ident,
        exp_args["testing"],
        model,
        save=False,
    )


def _make_callbacks(epochs):
    """Return a fresh callback list for a single training run of ``epochs`` epochs.

    New instances are built per run because ModelCheckpoint keeps internal
    state (e.g. the files it has saved) that should not leak between Trainers.
    """
    return [
        # ``dirpath`` is a directory, not a file name. With no ``monitor``,
        # ModelCheckpoint simply keeps the most recent checkpoint of the run;
        # it does not track validation reward.
        ModelCheckpoint(dirpath=f"./checkpoints/am_epochs_{epochs}"),
        # Print a model summary down to depth 3.
        RichModelSummary(max_depth=3),
    ]


args = utils.parse_experiment()
ident = 0
experiment_dir = f"results/exp_{ident}"
# exist_ok avoids the check-then-create race of os.path.exists + os.makedirs.
os.makedirs(experiment_dir, exist_ok=True)

test_results = {}
test_routes = {}

# Key 0: evaluation before any training performed by this script.
model = _build_model(args, ident)
test_results[0], test_routes[0] = _evaluate(model, args, ident)

for epoch in EPOCH_BUDGETS:
    # A new wrapper is built for every budget and no Lightning checkpoint is
    # resumed, so each entry comes from a separate run with max_epochs=epoch.
    # (Resuming via fit(ckpt_path="last") would need
    # ModelCheckpoint(save_last=True) to ever write a last.ckpt.)
    model = _build_model(args, ident)
    model.trainer = RL4COTrainer(
        max_epochs=epoch,
        accelerator="auto",
        default_root_dir=OUTPUT_DIR,
        callbacks=_make_callbacks(epoch),
    )
    model.trainer.fit(model.model)
    test_results[epoch], test_routes[epoch] = _evaluate(model, args, ident)

# Ensure the output directory exists even if the trainer did not create it.
os.makedirs(OUTPUT_DIR, exist_ok=True)
run_tag = (
    f"{args['method_settings']['init_method']}_"
    f"{args['method_settings']['customers']}"
)
# json.dump turns the integer epoch keys into strings ("0", "1", ...).
with open(f"{OUTPUT_DIR}/results_{run_tag}.json", "w") as f:
    json.dump(test_results, f, indent=2)
with open(f"{OUTPUT_DIR}/routes_{run_tag}.json", "w") as f:
    json.dump(test_routes, f, indent=2)