-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcreate_figures.py
65 lines (56 loc) · 2.04 KB
/
create_figures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
import argparse
# Set the default font family for all figure text; runs once at import time.
plt.rcParams["font.family"] = "Times New Roman"
# Model identifier (the part of a result file's stem after the first
# underscore) -> (legend label, line colour).
# The pretrained and finetuned variants of one architecture share a colour;
# plot_roc tells them apart by line style.
MODELS = {
    # VGG
    "pretrain_vgg": ("Pretrained VGG", "#7570b3"),
    "finetune_vgg": ("Finetuned VGG", "#7570b3"),
    # ResNet
    "pretrain_resnet": ("Pretrained ResNet", "#1b9e77"),
    "finetune_resnet": ("Finetuned ResNet", "#1b9e77"),
    # Autoencoder
    "pretrain_autoencoder": ("Pretrained Autoencoder", "#d95f02"),
    "finetune_autoencoder": ("Finetuned Autoencoder", "#d95f02"),
}
# Dataset key (matched as a substring of each result file name) -> subplot
# title. Insertion order fixes the left-to-right order of the subplots.
DATASETS = {
    "bio": "BINDER Test",
    "real": "PUBPEER",
    "mfnd": "MFND IND",
}
def plot_roc(args):
    """Plot ROC curves for every dataset side by side and save them as EPS.

    One subplot is drawn per entry in ``DATASETS``. A file in ``args.path``
    contributes a curve to a subplot when its name contains the dataset key
    and the part of its stem after the first underscore is a key of
    ``MODELS``. Each such file is read with ``np.load`` and must unpack into
    ``(fpr, tpr)``. Entries that do not match are skipped without being
    loaded.

    The figure is written to ``figures/roc.eps`` (the directory is created
    when missing) and then displayed with ``plt.show()``.

    Args:
        args: Object with a ``path`` attribute naming the directory that
            holds the saved ROC arrays.
    """
    # Sorted so curves are drawn, and listed in the legend, in a stable
    # order; sub-directories are ignored because they cannot be loaded.
    files = sorted(p for p in Path(args.path).iterdir() if p.is_file())

    fig = plt.figure(figsize=(16, 4))
    num_figs = len(DATASETS)
    for i, (key, name) in enumerate(DATASETS.items()):
        ax = fig.add_subplot(1, num_figs, i + 1)

        for file in files:
            if key not in file.name:
                continue

            # Resolve the model *before* touching the file so that unrelated
            # files sharing the dataset key are never passed to np.load.
            # partition() yields "" when the stem has no underscore, which
            # is simply not a known model (no IndexError).
            label = file.stem.partition("_")[2]
            if label not in MODELS:
                continue

            display_name, color = MODELS[label]
            # Finetuned models are dotted; pretrained ones keep the default
            # line style.
            style = {} if "pretrain" in label else {"linestyle": ":"}

            fpr, tpr = np.load(file)
            ax.plot(fpr, tpr, label=display_name, color=color, **style)

        if key == "mfnd":
            ax.set_xscale("log")
        else:
            # reference line
            ax.plot([0, 1], [0, 1], linestyle="--", color="k")

        ax.set_title(name)
        ax.set_xlabel("False Positive Rate")
        ax.set_ylabel("True Positive Rate")
        ax.legend()

    # savefig does not create missing directories, so make sure it exists.
    out_path = Path("figures/roc.eps")
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, format="eps", bbox_inches="tight", pad_inches=0.1)
    plt.show()
def main(args: argparse.Namespace) -> None:
    """Script entry point: draw and save the ROC figure for ``args.path``."""
    plot_roc(args)
if __name__ == "__main__":
    # One positional argument: the directory containing the saved ROC arrays.
    cli = argparse.ArgumentParser()
    cli.add_argument("path")
    main(cli.parse_args())