-
Notifications
You must be signed in to change notification settings - Fork 254
/
plot_utils.py
109 lines (84 loc) · 2.69 KB
/
plot_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import math
import matplotlib.pyplot as plt
import os
import subprocess
import tempfile
def get_fig_size(fig_width_cm, fig_height_cm=None):
    """Convert figure dimensions from centimeters to inches.

    Args:
        fig_width_cm: Figure width in centimeters.
        fig_height_cm: Figure height in centimeters. When omitted (or any
            other falsy value, e.g. 0), the height is computed from the
            width using the golden ratio.

    Returns:
        A ``(width, height)`` tuple expressed in inches.
    """
    if not fig_height_cm:
        golden_ratio = (1 + math.sqrt(5)) / 2
        fig_height_cm = fig_width_cm / golden_ratio

    # Build a real tuple instead of returning ``map(...)``: on Python 3
    # ``map`` yields a one-shot iterator that cannot be indexed, measured
    # with len() or iterated twice.
    cm_per_inch = 2.54
    return (fig_width_cm / cm_per_inch, fig_height_cm / cm_per_inch)
"""
The following functions can be used by scripts to get the sizes of
the various elements of the figures.
"""
def label_size():
    """Return the size used for axis labels."""
    size = 10
    return size
def font_size():
    """Return the size used for all texts shown in plots."""
    size = 10
    return size
def ticks_size():
    """Return the size used for the axes' ticks."""
    size = 8
    return size
def axis_lw():
    """Return the line width used for the axes."""
    width = 0.6
    return width
def plot_lw():
    """Return the line width used for the plotted curves."""
    width = 1.5
    return width
def figure_setup():
    """Set all the sizes to the correct values and use tex fonts for all texts.

    Updates the global ``plt.rcParams`` in place, so every figure created
    afterwards picks up the sizes returned by ``font_size()``,
    ``label_size()``, ``ticks_size()`` and ``axis_lw()``.
    """
    params = {
        'text.usetex': True,
        'figure.dpi': 200,
        'font.size': font_size(),
        # Empty family lists, as in the original settings.
        # NOTE(review): presumably this defers the font choice to TeX - confirm.
        'font.serif': [],
        'font.sans-serif': [],
        'font.monospace': [],
        'axes.labelsize': label_size(),
        'axes.titlesize': font_size(),
        'axes.linewidth': axis_lw(),
        # 'text.fontsize' is intentionally not set: it is not a valid
        # rcParam in current Matplotlib releases (rcParams.update() raises
        # KeyError for it) and 'font.size' above already sets the default
        # text size.
        'legend.fontsize': font_size(),
        'xtick.labelsize': ticks_size(),
        'ytick.labelsize': ticks_size(),
        'font.family': 'serif',
    }
    plt.rcParams.update(params)
def save_fig(fig, file_name, fmt=None, dpi=300, tight=True):
    """Save a Matplotlib figure as EPS/PNG/PDF to the given path and trim it.

    The figure is first written to a temporary file, which is then trimmed
    into ``file_name`` by an external command: ``epstool`` for EPS,
    ``convert`` for PNG and ``pdfcrop`` for PDF. The temporary file is
    always removed afterwards.

    Args:
        fig: Object exposing ``savefig(path, dpi=..., bbox_inches=...)``,
            e.g. a Matplotlib figure.
        file_name: Destination path. The extension ``.<fmt>`` is appended
            when the path does not already end with it.
        fmt: One of ``'eps'``, ``'png'`` or ``'pdf'``. When omitted, it is
            taken from the text after the last ``.`` in ``file_name``.
        dpi: Resolution passed to ``fig.savefig``.
        tight: When true, save with ``bbox_inches='tight'``.

    Raises:
        ValueError: If the format is not one of eps/png/pdf.
        OSError: If the external trimming command cannot be started.
        subprocess.CalledProcessError: If the trimming command exits with a
            non-zero status.
    """
    if not fmt:
        fmt = file_name.strip().split('.')[-1]

    if fmt not in ['eps', 'png', 'pdf']:
        raise ValueError('unsupported format: %s' % (fmt,))

    extension = '.%s' % (fmt,)
    if not file_name.endswith(extension):
        file_name += extension

    file_name = os.path.abspath(file_name)

    # Create the temporary file with the right suffix and close our handle
    # immediately, so that savefig and the external tool can open the path
    # themselves. It is removed in the ``finally`` clause below.
    handle, tmp_name = tempfile.mkstemp(suffix=extension)
    os.close(handle)
    try:
        # save figure
        if tight:
            fig.savefig(tmp_name, dpi=dpi, bbox_inches='tight')
        else:
            fig.savefig(tmp_name, dpi=dpi)

        # trim it; the command is an argument list (no shell), so paths
        # containing spaces or shell metacharacters are passed through
        # unchanged.
        if fmt == 'eps':
            command = ['epstool', '--bbox', '--copy', tmp_name, file_name]
        elif fmt == 'png':
            command = ['convert', tmp_name, '-trim', file_name]
        else:
            command = ['pdfcrop', tmp_name, file_name]

        # Fail loudly instead of silently producing no output file.
        subprocess.check_call(command)
    finally:
        if os.path.exists(tmp_name):
            os.remove(tmp_name)