forked from abigailhayes/rl-for-vrp
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplots_instances.py
186 lines (153 loc) · 6.27 KB
/
plots_instances.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import vrplib
import json
import os
import argparse
import pandas as pd
def _str_to_bool(value):
    """Convert a command-line string such as "True", "false", "1" or "no" to a bool.

    argparse does not interpret booleans itself: without a converter every
    non-empty string (including "False") would be truthy.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ("1", "true", "t", "yes", "y", "on"):
        return True
    if lowered in ("0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}")


def parse_plots():
    """Parse command-line arguments for plot generation.

    Returns:
        A dict with keys ``expt``, ``expt_ids``, ``instance_name``,
        ``instance_set`` and ``demand``, i.e. the parameters of
        ``generate_plots``. Unrecognised arguments are ignored
        (``parse_known_args``).

    ``--demand`` accepts ``--demand`` on its own (True), or an explicit
    boolean string such as ``--demand True`` / ``--demand False``.
    """
    parser = argparse.ArgumentParser(description="Plot arguments")
    parser.add_argument("--expt", help="Experiment - a, b or c")
    parser.add_argument(
        "--expt_ids", default=[], help="Experiment ids to plot", type=str, nargs="*"
    )
    parser.add_argument("--instance_name", help="Instance name including file type")
    parser.add_argument("--instance_set", help="Folder containing the instance")
    parser.add_argument(
        "--demand",
        default=False,
        type=_str_to_bool,  # convert "False"/"True" strings to real booleans
        nargs="?",
        const=True,  # bare "--demand" switches labels on
        help="Whether demand is on plots",
    )
    args, _unknown = parser.parse_known_args()
    return vars(args)
def plot_solution(instance, solution, name, experiment_id, expt_desc, demand=False):
    """Plot the routes of a solution and save the figure as a JPEG.

    Adapted from https://alns.readthedocs.io/en/stable/examples/capacitated_vehicle_routing_problem.html

    Args:
        instance: Instance mapping (as loaded by ``vrplib.read_instance`` in
            this script). ``node_coord`` is read, with node 0 drawn as the
            depot; ``demand`` is read when ``demand`` is truthy.
        solution: Mapping with ``routes`` (list of node-index sequences, each
            without the depot) and ``cost`` (shown truncated to int).
        name: Instance short name; the figure is written to
            ``analysis/plots/{name}/expt_{experiment_id}.jpg``. The folder
            must already exist.
        experiment_id: Identifier used in the output file name.
        expt_desc: Human-readable method description shown in the title.
        demand: If truthy, annotate each customer node with its demand.
    """
    coords = instance["node_coord"]
    fig, ax = plt.subplots(figsize=(12, 10))
    colours = matplotlib.cm.rainbow(np.linspace(0, 1, len(solution["routes"])))
    for idx, route in enumerate(solution["routes"]):
        # Draw each route as a closed tour leaving from and returning to node 0.
        # list() guards against array-like routes, where "+" would add elementwise.
        tour = [0] + list(route) + [0]
        ax.plot(
            [coords[loc][0] for loc in tour],
            [coords[loc][1] for loc in tour],
            color=colours[idx],
            marker=".",
        )
    # Plot the depot
    ax.scatter(coords[0][0], coords[0][1], c="tab:red", s=250)
    ax.set_title(f"{name}: {expt_desc}\n Total distance: {int(solution['cost'])}")
    if demand:
        # Label node k with demand[k]. The previous code used demand[k - 1],
        # which is off by one if demand is indexed by node.
        # NOTE(review): assumes demand has one entry per node with the depot
        # at index 0 (vrplib's layout) - confirm for other instance sources.
        for node, (xi, yi) in enumerate(coords[1:], start=1):
            ax.text(xi, yi, instance["demand"][node], va="bottom", ha="center")
    fig.tight_layout()
    fig.savefig(f"analysis/plots/{name}/expt_{experiment_id}.jpg")
    plt.close(fig)
def plot_instance(instance, name, demand=False):
    """Plot the nodes of an instance (no routes) and save the figure as a JPEG.

    Args:
        instance: Instance mapping (as loaded by ``vrplib.read_instance`` in
            this script). ``node_coord`` and ``dimension`` are read, with node
            0 drawn as the depot; ``demand`` is read when ``demand`` is truthy.
        name: Instance short name; the figure is written to
            ``analysis/plots/{name}/instance.jpg``. The folder must already
            exist.
        demand: If truthy, annotate each customer node with its demand.
    """
    fig, ax = plt.subplots(figsize=(6, 5))
    # NOTE(review): this mutates the *global* matplotlib rcParams, so font size
    # 16 also applies to every figure created later in the process (e.g. the
    # solution plots made afterwards by generate_plots). Kept, and kept after
    # subplots(), so that rendered output is unchanged.
    plt.rcParams.update({"font.size": 16})
    coords = instance["node_coord"]
    # All customers share one colour: the first colour of the rainbow colormap.
    colour = matplotlib.cm.rainbow(np.linspace(0, 1, 1))[0]
    customers = range(1, instance["dimension"])
    ax.scatter(
        [coords[loc][0] for loc in customers],
        [coords[loc][1] for loc in customers],
        color=colour,
    )
    # Plot the depot
    ax.scatter(coords[0][0], coords[0][1], c="tab:red", s=250)
    ax.set_title(f"{name}\n Customers: {instance['dimension'] - 1}")
    if demand:
        # Label node k with demand[k]. The previous code used demand[k - 1],
        # which is off by one if demand is indexed by node.
        # NOTE(review): assumes demand has one entry per node with the depot
        # at index 0 (vrplib's layout) - confirm for other instance sources.
        for node, (xi, yi) in enumerate(coords[1:], start=1):
            ax.text(xi, yi, instance["demand"][node], va="bottom", ha="center")
    fig.savefig(f"analysis/plots/{name}/instance.jpg")
    # Release the figure; it was previously left open (leak across calls).
    plt.close(fig)
def _format_customers(value):
    """Render a customer count for display without a spurious trailing ``.0``.

    pandas stores an integer column as float64 when it contains missing
    values, so a count of 20 can arrive as ``20.0``.
    """
    if isinstance(value, float) and value.is_integer():
        return str(int(value))
    return str(value)


def gen_expt_desc(settings_df, expt_id):
    """Build a human-readable description of the method used by an experiment.

    Args:
        settings_df: DataFrame of experiment settings. Columns read: ``id``,
            ``method``, ``init_method``, ``improve_method`` (ortools only) and
            ``customers`` (non-ortools methods only).
        expt_id: Experiment id; converted with ``int()`` before matching ``id``.

    Returns:
        The description string. For a ``method`` not handled below, the raw
        method name is returned (previously ``None``).

    Raises:
        IndexError: if no row has ``id == int(expt_id)``.
    """
    row = settings_df[settings_df["id"] == int(expt_id)].iloc[0]
    method = row["method"]

    def pretty(text):
        # e.g. "path_cheapest_arc" -> "Path Cheapest Arc"
        return text.replace("_", " ").title()

    if method == "ortools":
        if pd.isna(row["improve_method"]):
            return pretty(row["init_method"])
        return pretty(row["init_method"]) + " & " + pretty(row["improve_method"])
    if method == "rl4co":
        return f"{row['init_method'].upper()} {_format_customers(row['customers'])}"
    if method == "rl4co_tsp":
        return f"{row['init_method'].upper()} TSP {_format_customers(row['customers'])}"
    if method == "rl4co_mini":
        return f"{row['init_method'].upper()} {_format_customers(row['customers'])} Mini"
    if method == "nazari":
        return f"Nazari {_format_customers(row['customers'])}"
    # Unknown method: fall back to its raw name instead of returning None,
    # which would otherwise appear as "None" in plot titles.
    return str(method)
def _load_expt_solution(expt, expt_id, instance_set, instance_name):
    """Load one experiment run's routes and cost for a single instance.

    Reads ``results/exp_{expt_id}/routes_{expt}.json`` and
    ``results/exp_{expt_id}/results_{expt}.json``, both indexed as
    ``[instance_set][instance_name]``.

    Raises:
        OSError: if either results file cannot be opened.
        KeyError: if the instance is missing from either file.
    """
    results_dir = f"results/exp_{expt_id}"
    with open(f"{results_dir}/routes_{expt}.json", encoding="utf-8") as fh:
        route_file = json.load(fh)
    with open(f"{results_dir}/results_{expt}.json", encoding="utf-8") as fh:
        cost_file = json.load(fh)
    return {
        # Empty routes (presumably unused vehicles) are dropped so they are not drawn.
        "routes": [
            route for route in route_file[instance_set][instance_name] if len(route) > 0
        ],
        "cost": cost_file[instance_set][instance_name],
    }


def generate_plots(expt, expt_ids: list, instance_name, instance_set, demand=False):
    """Generate the instance plot and one solution plot per experiment run.

    Args:
        expt: Id of the overall experiment; only "a" and "b" are supported.
        expt_ids: Experiment run ids whose solutions should be plotted.
        instance_name: File name of the instance, e.g. ``"X-n101-k25.vrp"``.
        instance_set: The set the instance belongs to, i.e. the next-level
            folder under ``instances/CVRP`` (under ``generate/`` for expt "b").
        demand: Whether to label each node with its demand.

    Raises:
        ValueError: if ``expt`` is not "a" or "b". This was previously an
            UnboundLocalError.

    Runs whose results files or instance entries are missing are skipped
    with a printed message.
    """
    # Resolve the instance folder first so that bad input fails before any work.
    if expt == "b":
        instance_folder = f"generate/{instance_set}"
    elif expt == "a":
        instance_folder = instance_set
    else:
        raise ValueError(f"Unsupported experiment {expt!r}; expected 'a' or 'b'")

    # Strip a trailing ".vrp" to get the short name used for output folders.
    short_name = instance_name[: -len(".vrp")] if instance_name.endswith(".vrp") else instance_name
    # Create an instance-specific output folder.
    os.makedirs(f"analysis/plots/{short_name}", exist_ok=True)

    # Plot the instance without routes, and the optimal solution if available.
    instance_dir = f"instances/CVRP/{instance_folder}"
    instance = vrplib.read_instance(f"{instance_dir}/{instance_name}")
    plot_instance(instance, short_name, demand)
    if expt == "a":
        # For experiment "a" a reference .sol file is read next to the instance.
        solution = vrplib.read_solution(f"{instance_dir}/{short_name}.sol")
        plot_solution(instance, solution, short_name, "optimum", "Optimum", demand)

    settings_df = pd.read_csv("results/other/settings.csv")
    # One solution plot per experiment id (best effort: missing runs are skipped).
    for expt_id in expt_ids:
        print("Expt: ", expt_id)
        try:
            solution = _load_expt_solution(expt, expt_id, instance_set, instance_name)
            expt_desc = gen_expt_desc(settings_df, expt_id)
            plot_solution(instance, solution, short_name, expt_id, expt_desc, demand)
        except (OSError, KeyError) as err:
            # Previously swallowed silently; report so missing results are visible.
            print(f"  Skipping expt {expt_id} for {instance_name}: {err!r}")
def main():
    """Command-line entry point: read the plot options and generate the plots."""
    options = parse_plots()
    generate_plots(
        expt=options["expt"],
        expt_ids=options["expt_ids"],
        instance_name=options["instance_name"],
        instance_set=options["instance_set"],
        demand=options["demand"],
    )
# Run the plotting CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()